##################################################
# Author: Tyson Dial #
# Email: tdial@swin.edu.au #
# Date (created): 17/03/2024 #
# Date (updated): 17/03/2024 #
##################################################
# Make a multi-panel plot #
# #
##################################################
## imports
from ..frb import FRB
from ..data import *
from ..utils import load_param_file, dict_get, fix_ds_freq_lims
from ..plot import _PLOT, plot_PA, plot_dynspec
from ..pyfit import fit
from ..fitting import make_scatt_pulse_profile_func
import yaml
import numpy as np
import matplotlib.pyplot as plt
def plot_master(parfile, plot_panels = "[S;D]", model = False, modelpar = None,
                modelpulses = False, filename = None):
    """
    Build a multi-panel FRB plot described by a parameter file.

    Parameters
    ----------
    parfile : str
        Path to the FRB parameter file (used to create the ``FRB`` instance and
        loaded with ``load_param_file``).
    plot_panels : str
        Panel layout specification, parsed by ``_init_figure``. The panel letters
        acted on by ``_plot`` are D (dynamic spectrum), S (Stokes time series),
        M (model over data), R (residuals) and P (position angle).
    model : bool
        If True, a pulse model of Stokes I is built and overplotted on the
        Stokes panel (unless a dedicated model panel is requested).
    modelpar : str, optional
        Path to a YAML file holding ``npulse`` and ``posterior`` entries; when
        given, the model is built from these values instead of being fitted.
    modelpulses : bool
        If True, each individual pulse component of the model is also drawn.
    filename : str, optional
        If given, the figure is saved to this path (the value is also forwarded
        to ``frb.fit_tscatt`` when a fit is run).

    Returns
    -------
    fig
        The figure created by ``_make_figure`` with all panels drawn
        (presumably a matplotlib ``Figure``).
    """
    options = {
        "parfile": parfile,
        "plot_panels": plot_panels,
        "model": model,
        "modelpar": modelpar,
        "modelpulses": modelpulses,
        "filename": filename,
    }

    # populate the argument container consumed by _init_figure and _plot
    args = _empty
    for attr_name, attr_value in options.items():
        setattr(args, attr_name, attr_value)

    # figure layout parameters and per-panel flags
    figpar, flags = _init_figure(args)

    return _plot(args, figpar, flags)
def _plot_model_pulses(ax, x, npulse, posterior):
    """
    Overplot each single-pulse component of a multi-pulse model.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    x : array-like
        Samples at which each pulse component is evaluated.
    npulse : int
        Number of pulse components; posterior keys ``a{i}``, ``mu{i}`` and
        ``sig{i}`` are read for ``i = 1 .. npulse``.
    posterior : dict
        Best-fit parameter values; must also contain the shared ``tau`` value.
    """
    single_pulse = make_scatt_pulse_profile_func(1)
    for i in range(1, npulse + 1):
        y = single_pulse(x, a1 = posterior[f"a{i}"], tau = posterior["tau"],
                         mu1 = posterior[f"mu{i}"], sig1 = posterior[f"sig{i}"])
        ax.plot(x, y, "--", linewidth = 1.0)


def _model_from_posterior(data, npulse, posterior):
    """
    Build a ``fit`` object for Stokes I with its posterior fixed to given values.

    No fitting is performed; each posterior entry is set with zero errors.

    Parameters
    ----------
    data : dict
        Data products; ``time``, ``tI`` and ``tIerr`` are read.
    npulse : int
        Number of pulse components in the model function.
    posterior : mapping
        Parameter name -> value.

    Returns
    -------
    p : fit
        Fit object whose model can be evaluated with the supplied posterior.
    """
    p = fit(x = data['time'], y = data['tI'],
            yerr = data['tIerr'] * np.ones(data['tI'].size),
            func = make_scatt_pulse_profile_func(npulse))
    for key, value in posterior.items():
        p.set_posterior(key, value, 0.0, 0.0)
    # flag as already fitted (presumably so get_model()/get_post_val() use the
    # values set above without refitting)
    p._is_fit = True
    return p


def _load_or_fit_model(args, frb, pars, data, need_model):
    """
    Obtain the Stokes I pulse model used by the model/residual/Stokes panels.

    Priority: a model parameter file (``args.modelpar``) is always loaded when
    given; otherwise, if ``need_model`` is True, the model is taken from the
    time-weighting function parameters (when ``weights.time.method == "func"``)
    or fitted with ``frb.fit_tscatt``.

    Returns
    -------
    (p, npulse) : tuple
        The fit object and its number of pulses, or ``(None, 0)`` when no model
        was requested.
    """
    if args.modelpar is not None:
        with open(args.modelpar, "r") as file:
            model_par = yaml.safe_load(file)
        npulse = model_par['npulse']
        return _model_from_posterior(data, npulse, model_par['posterior']), npulse

    if not need_model:
        return None, 0

    time_weights = pars['weights']['time']
    if (time_weights['func'] is not None) and (time_weights['method'] == "func"):
        # three parameters per pulse plus one shared parameter
        npulse = (len(time_weights['args']) - 1) // 3
        return _model_from_posterior(data, npulse, time_weights['args']), npulse

    # run model
    tscatt = pars['fits']['tscatt']
    p = frb.fit_tscatt(method = pars['fits']['fitmethod'], npulse = tscatt['npulse'],
                       priors = tscatt['priors'], statics = tscatt['statics'],
                       fit_params = tscatt['fit_params'], redo = pars['fits']['redo'],
                       filename = args.filename)
    return p, tscatt['npulse']


def _plot(args, figpar, flags):
    """
    Load FRB data, optionally build a pulse model, and draw the requested panels.

    Parameters
    ----------
    args : object
        Container with ``parfile``, ``model``, ``modelpar``, ``modelpulses`` and
        ``filename`` attributes (filled by ``plot_master``).
    figpar : object
        Figure parameters forwarded to ``_make_figure``.
    flags : mapping
        Panel flags keyed by panel letter: ``'D'`` dynamic spectrum, ``'S'``
        Stokes time series, ``'M'`` model over data, ``'R'`` residuals,
        ``'P'`` polarisation position angle.

    Returns
    -------
    fig
        The figure created by ``_make_figure`` with all panels drawn; saved to
        ``args.filename`` when that is not None.
    """
    # create FRB instance; showing/saving is handled here rather than by FRB
    frb = FRB(args.parfile)
    frb.set(show_plots = False, save_plots = False)

    # the residual panel needs a model just as much as the model panel does
    need_model = args.model or flags['M'] or flags['R']

    # work out which data products the requested panels and model need
    data_list = []
    if flags['D']:
        data_list += ['dsI']
    if need_model or (args.modelpar is not None) or flags['P']:
        # models are built from tI/tIerr, and PA debiasing reads tIerr
        data_list += ['tI']
    if flags['P']:
        data_list += ['tQ', 'tU']
    if flags['S']:
        data_list += ['tI', 'tU', 'tQ', 'tV']
    # drop duplicates while keeping a deterministic order
    data_list = list(dict.fromkeys(data_list))

    data = frb.get_data(data_list, get = True)
    pars = load_param_file(args.parfile)

    p, npulses = _load_or_fit_model(args, frb, pars, data, need_model)

    # create the figure only after modelling, since fitting may make its own plots
    fig, AX = _make_figure(figpar)

    # plot dynamic spectrum
    if flags['D']:
        ds_freq_lims = fix_ds_freq_lims(frb.this_par.f_lim, frb.this_par.df)
        plot_dynspec(ds = data['dsI'], ax = AX['D'], aspect = 'auto',
                     extent = [*frb.this_par.t_lim, *ds_freq_lims])
        AX['D'].set(ylabel = "Freq [MHz]")

    # plot Stokes time series
    if flags['S']:
        frb.plot_stokes(ax = AX['S'], stk_type = "t", Ldebias = pars['plots']['stk_debias'],
                        sigma = pars['plots']['stk_sigma'], stk_ratio = pars['plots']['stk_ratio'],
                        stk2plot = pars['plots']['stk2plot'])
        AX['S'].set(ylabel = "Flux Density (arb.)")
        # overlay the model here only when it has no dedicated panel
        if args.model and (not flags['M']):
            AX['S'].plot(*p.get_model(), color = 'coral', linewidth = 2)
            if args.modelpulses:
                _plot_model_pulses(AX['S'], p.x, npulses, p.get_post_val())

    # plot model over data
    if flags['M']:
        _PLOT(ax = AX['M'], x = p.x, y = p.y, yerr = p.yerr,
              plot_type = pars['plots']['plot_type'], color = 'k')
        AX['M'].plot(*p.get_model(), color = [0.9098, 0.364, 0.3961], linewidth = 1.5)
        if args.modelpulses:
            _plot_model_pulses(AX['M'], p.x, npulses, p.get_post_val())
        AX['M'].set(ylabel = "Flux Density (arb.)")

    # plot residuals (data - model)
    if flags['R']:
        _PLOT(ax = AX['R'], x = p.x, y = p.y - p.get_model()[1], yerr = p.yerr,
              plot_type = pars['plots']['plot_type'], color = 'k')
        AX['R'].set(ylabel = "Flux Density (arb.)")

    # plot polarisation position angle (PA)
    if flags['P']:
        PA, PAerr = calc_PAdebiased(dict_get(data, ["tU", "tQ", "tUerr", "tQerr", "tIerr"]),
                                    Ldebias_threshold = pars['plots']['Ldebias_threshold'])
        plot_PA(data['time'], PA, PAerr, ax = AX['P'], flipPA = pars['plots']['flipPA'],
                plot_type = "scatter")

    # final figure adjustments
    fig.tight_layout()
    fig.subplots_adjust(hspace = 0)

    if args.filename is not None:
        # save this figure explicitly rather than pyplot's "current" figure
        fig.savefig(args.filename)

    return fig