##===============================================##
##===============================================##
## Author: Tyson Dial
## Email: tdial@swin.edu.au
## Last Updated: 25/09/2023
##
##
##
##
## Library of functions to plot data
##
##
##
##===============================================##
##===============================================##
# imports
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import os
import cmasher as cmr
from .frbutils import get_dynspec_plot_properties
## import stats functions ##
from .fitting import (lorentz, scatt_pulse_profile, scat,
gaussian, model_curve)
from .globals import c, _G
from .data import *
from .utils import load_plotstyle, fix_ds_freq_lims
# constants
default_col = plt.rcParams['axes.prop_cycle'].by_key()['color']
#--------------------#
# to set up as global set func
#--------------------#
ILEX_PLOT_FONTSIZE = 16
ILEX_PLOT_ERRTYPE = "lines"
#-------------------------------------------------#
# UTILITY FUNCTIONS FOR PLOTTING #
#-------------------------------------------------#
def _data_from_dict(dic, keys):
"""
Get data from dictionary
"""
# check if data is there
dic_keys = dic.keys()
out_dic = {}
# put data into output dictionary
for key in keys:
if key not in dic_keys:
print("not all data given")
return (None, ) * 2
out_dic[key] = dic[key].copy()
for key in ["freq", "time"]:
if key in dic_keys:
out_dic[key] = dic[key].copy()
# now check if error data has been given
err_flag = True
err_keys = []
for key in keys:
if (key not in ["freq", "time"]) and ("err" not in key):
err_keys += [f"{key}err"]
if f"{key}err" in dic_keys:
if dic[f"{key}err"] is not None:
out_dic[f"{key}err"] = dic[f"{key}err"].copy()
else:
err_flag = False
else:
err_flag = False
# if false, set all errors to false for convenience
if not err_flag:
for key in err_keys:
out_dic[key] = None
return out_dic, err_flag
# wrapper for better function definition
[docs]
def plot(x, y, yerr = None, ax = None, plot_type = "lines", color = None, alpha = 0.5, **kwargs):
_PLOT(x = x, y = y, yerr = yerr, ax = ax, plot_type = plot_type, color = color,
alpha = alpha, **kwargs)
return
def _PLOT(x, y, yerr = None, ax = None, plot_type = "lines", color = None, alpha = 0.5,
**kwargs):
"""
General plotting function
"""
if plot_type == "scatter":
_PLOT_SCATTER(x = x, y = y, yerr = yerr, ax = ax, color = color, alpha = alpha, **kwargs)
elif plot_type == "lines":
_PLOT_LINES(x = x, y = y, yerr = yerr, ax = ax, color = color, alpha = alpha, **kwargs)
else:
print("Plot err style undefined/unsupported. ")
return
def _PLOT_SCATTER(x, y, yerr = None, ax = None, color = None, alpha = 0.5, **kwargs):
"""
Plot lines
"""
plot_pars = load_plotstyle()
for key in kwargs:
if key in _G.scatter_args:
plot_pars['scatter'][key] = kwargs[key]
continue
if key in _G.errorbar_args:
plot_pars['errorbar'][key] = kwargs[key]
# check if colors not in
for _p in ['c', 'facecolors']:
if 'c' in plot_pars['scatter'].keys():
del plot_pars['scatter']['c']
for _p in ['ecolor', 'alpha', 'markerfacecolor']:
if _p in plot_pars['errorbar'].keys():
del plot_pars['errorbar'][_p]
# plot scatter
if ax is not None:
sc = ax.scatter(x, y, c = color, facecolors = color, **plot_pars['scatter'])
if yerr is not None:
if color is None:
color = sc.get_facecolors()[0]
ax.errorbar(x = x, y = y, yerr = yerr, ecolor = color,
alpha = alpha, markerfacecolor = color,
**plot_pars['errorbar'])
else:
sc = plt.scatter(x = x, y = y, c = color, facecolors = color, **plot_pars['scatter'])
if yerr is not None:
if color is None:
color = sc.get_facecolors()[0]
plt.errorbar(x, y, yerr = yerr, ecolor = color,
alpha = alpha, markerfacecolor = color, **plot_pars['errorbar'])
return
def _PLOT_LINES(x, y, yerr = None, ax = None, color = None, alpha = 0.5, **kwargs):
"""
Plot lines
"""
plot_pars = load_plotstyle()
for key in kwargs:
if key in _G.plot_args:
plot_pars['plot'][key] = kwargs[key]
if 'color' in plot_pars['plot'].keys():
del plot_pars['plot']['color']
# plot region
if ax is not None:
ln, = ax.plot(x, y, color = color, **plot_pars['plot'])
if yerr is not None:
ax.fill_between(x, y-yerr, y+yerr, color = ln.get_color(), alpha = alpha,
edgecolor = None)
else:
ln, = plt.plot(x, y, color = color, **plot_pars['plot'])
if yerr is not None:
plt.fill_between(x, y-yerr, y+yerr, color = ln.get_color(), alpha = alpha,
edgecolor = None)
return
[docs]
def plot_dynspec(ds, ax = None, **kwargs):
"""
Plot dynamic spectrum
Parameters
----------
ds : np.ndarray
dynamic spectrum
ax : axes, optional
axes to plot dynspec, by default None
"""
ds[np.isnan(ds[:,0])] = 0
properties = get_dynspec_plot_properties()
for key in kwargs.keys():
properties[key] = kwargs[key]
# cases
if "cmap" in properties.keys():
if properties["cmap"] in plt.colormaps():
pass
elif properties["cmap"] == "arctic":
properties['cmap'] = cmr.arctic_r
else:
print("Colorbar not supported, either must be a known matplotlib colorbar or [artic] from cmasher")
del properties['cmap']
if ax is not None:
ax.imshow(ds, **properties)
else:
plt.imshow(ds, **properties)
return
[docs]
def plot_data(dat, typ = "dsI", ax = None, plot_type = "scatter"):
"""
Plot data
Parameters
----------
dat : Dict(np.ndarray)
Dictionary of stokes data, can include any data products
typ : str, optional
Type of data to plot, by default "dsI" \n
[ds] - dynamic spectra \n
[t] - time series \n
[f] - frequency spectra
ax : Axes, optional
axes handle, by default None
plot_type : str, optional
type of plotting, by default "scatter"
Returns
-------
fig : figure
Return Figure instance
"""
##==================##
## PLOT START GUARD ##
##==================##
fig_flag = True
if ax is None:
fig_flag = False
fig = plt.figure(figsize = (10, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
else:
fig = None
err_flag = False
pdat, err_flag = _data_from_dict(dat, list([typ]))
if pdat is None:
return
# check if freq array was given, else make phase array
if "freq" not in pdat.keys():
fname = "Freq (phase)"
flim = [0.0, 1.0]
else:
fname = "Freq [MHz]"
df = abs(pdat['freq'][1] - pdat['freq'][0])
flim = [np.min(pdat['freq']), np.max(pdat['freq'])] # for ds plotting
fscat_lim = flim.copy()
flim[0] -= df/2
flim[1] += df/2
# check if time array was given, else make phase array
if "time" not in pdat.keys():
tname = "Time (phase)"
tlim = [0.0, 1.0]
else:
tname = "Time [ms]"
tlim = [pdat['time'][0], pdat['time'][-1]]
dt = pdat['time'][1] - pdat['time'][0]
tscat_lim = tlim.copy()
tlim[0] -= dt/2
tlim[-1] += dt/2
# utility functions
def plot_freq(x, y):
ax.plot(x, y, 'k')
ax.set(xlabel = fname, ylabel = "Flux Density")
# check type
if typ[0:2] == "ds":
# plot dynspec
plot_dynspec(pdat[typ], ax = ax, aspect = 'auto', extent = [*tlim, *flim])
# ax.imshow(pdat[typ], aspect = 'auto', )
ax.set(xlabel = tname, ylabel = fname)
elif typ[0] == "t":
# scrunch in freq
tx = np.linspace(*tscat_lim, pdat[typ].size)
ax.set(xlabel = tname, ylabel = "Flux Density (arb.)")
_PLOT(tx, pdat[typ], pdat[f"{typ}err"], ax = ax, color = 'k', alpha = 0.5,
plot_type = plot_type)
elif typ[0] == "f":
# scrunch in time
fx = np.linspace(*fscat_lim, pdat[typ].size)[::-1]
ax.set(xlabel = fname, ylabel = "Flux Density (arb.)")
_PLOT(fx, pdat[typ], pdat[f"{typ}err"], ax = ax, color = 'k', alpha = 0.5,
plot_type = plot_type)
else:
print("Invalid data type to plot")
return fig
# def plot_RM_fit(f, pa, rm, pa0, f0, residual = False):
[docs]
def plot_RM(f, Q, U, Qerr = None, Uerr = None, rm = 0.0, pa0 = 0.0, f0 = 0.0,
ax = None, filename: str = None, plot_type = "scatter"):
"""
Plot RM fit
Parameters
----------
f : np.ndarray
Frequency array
Q : np.ndarray
Stokes Q spectrum
U : np.ndarray
Stokes U spectrum
Qerr : np.ndarray, optional
Stokes Q error spectrum, by default None
Uerr : np.ndarray, optional
Stokes U error spectrum, by default None
rm : float, optional
Rotation Measure [rad/m^2], by default 0.0
pa0 : float, optional
position angle at f0, by default 0.0
f0 : float, optional
reference frequency [MHz], by default 0.0
ax : Axes, optional
Axes handle, by default None
filename : str, optional
filename to save figure to, by default None
plot_type : str, optional
type of error to plot, by default "scatter"
Returns
-------
fig : figure
Return figure instance
"""
##==================##
## PLOT START GUARD ##
##==================##
fig_flag = True
if ax is None:
fig_flag = False
fig = plt.figure(figsize = (10, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
err_flag = (Uerr is not None and Qerr is not None)
def rmquad(f, rm, pa0):
angs = pa0 + rm*c**2/1e12*(1/f**2 - 1/f0**2)
return 0.5*np.arctan2(np.sin(2*angs), np.cos(2*angs))
# set up axes
ax.set_xlabel("Frequency [MHz]", fontsize = 12)
ax.set_ylabel("PA [deg]", fontsize = 12)
# calc PA
# PA = 0.5 * np.arctan2(U, Q)
PA, PAerr = calc_PA(Q, U, Qerr, Uerr)
PA_fit = rmquad(f, rm, pa0)
_PLOT(x = f, y = PA*180/np.pi, yerr = PAerr*180/np.pi, ax = ax, color = 'k', alpha = 0.5,
plot_type = plot_type)
# # plot
# # ax.scatter(f, PA * 180/np.pi, c = 'k', s = 5, label = "Measured PA") # PA from data
# # ax.plot(f, PA_fit * 180/np.pi, 'r', label = f"RM: {rm:.3f}, pa0: {pa0:.3f}") # PA best fit line
# # error plotting
# if err_flag:
# _, PA_err = calc_PA(Q, U, Qerr, Uerr)
# _plot_err(f, PA * 180/np.pi, PA_err * 180/np.pi, ax = ax, col = [0., 0., 0., 0.5],
# plot_type = plot_err_type)
ax.set_ylim([-90, 90])
ax.legend()
##================##
## PLOT END GUARD ##
##================##
if not fig_flag:
if filename is not None:
plt.savefig(filename)
plt.show()
return fig
return None
[docs]
def plot_PA(x, PA, PA_err, ax = None, flipPA = False,
plot_type = "scatter", **kwargs):
"""
Plot PA profile
Parameters
----------
x : np.ndarray
X data
PA : np.ndarray
Position angle
PA_err : np.ndarray
PA error
ax : Axes, optional
Axes handle, by default None
flipPA : bool, optional
plot PA over [0, 180] degrees instead of [-90, 90], by default False
plot_type : str, optional
type of error to plot, by default "scatter"
Returns
-------
fig : figure
Return figure instance
"""
##==================##
## PLOT START GUARD ##
##==================##
fig_flag = True
if ax is None:
fig_flag = False
fig = plt.figure(figsize = (10, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
else:
fig = None
# make axes
ax.set_ylabel("PA [deg]", fontsize = 12)
ax.set_xlabel("time [ms]", fontsize = 12)
# flip PA
if flipPA:
PA[PA < 0] += np.pi
if "alpha" not in kwargs.keys():
kwargs['alpha'] = 0.5
if "color" not in kwargs.keys():
kwargs['color'] = 'k'
# plot PA
# PA_mask = ~np.isnan(PA)
# ax.scatter(x, PA * 180/np.pi, c = 'k', s = 2)
_PLOT(x = x, y = PA * 180/np.pi, yerr = PA_err * 180/np.pi,
plot_type = plot_type, ax = ax, **kwargs)
paw = np.nanmax(PA * 180/np.pi) - np.nanmin(PA * 180/np.pi)
ax.set_ylim([np.nanmin((PA - PA_err) * 180/np.pi) - 0.1*paw, np.nanmax((PA + PA_err) * 180/np.pi) + 0.1*paw])
# if flipPA:
# ax.set_ylim([0, 180])
# else:
# ax.set_ylim([-90, 90])
return fig
[docs]
def plot_stokes(dat, Ldebias = False, sigma = 2.0, stk_type = "f", stk2plot = "IQUV",
stk_ratio = False, ax = None, plot_type = "scatter"):
"""
Plot Stokes data, by default stokes I, Q, U and V data is plotted
Parameters
----------
dat : Dict(np.ndarray)
Dictionary of stokes data, can include any data products but must include the following: \n
[<x>I] - Stokes I data \n
[<x>Q] - Stokes Q data \n
[<x>U] - Stokes U data \n
[<x>V] - Stokes V data \n
[<x>Ierr] - Stokes I error data, only if Ldebias = True or stk_ratio = True \n
where <x> is either 't' for time and 'f' for freq
Ldebias : bool, optional
Plot stokes L debias, by default False
sigma : float, optional
sigma threshold for error masking, I < sigma * Ierr, mask it out or
else weird overflow behavior might be present when calculating stokes ratios, by default 2.0
stk_type : str, optional
Type of stokes data to plot, "f" for Stokes Frequency data or "t" for time data, by default "f"
stk2plot : str, optional
string of stokes to plot, for example if "QV", only stokes Q and V are plotted, by default "IQUV", choice between "IQUVLP"
stk_ratio : bool, optional
if true, plot stokes ratios S/I
plot_type : str, optional
Choose between two methods of plotting the error in the data, by default "scatter" \n
[lines] - plot lines with error patches
[scatter] - Show error in data as tics in markers
Returns
-------
fig : figure
Return figure instance
"""
##==================##
## PLOT START GUARD ##
##==================##
fig_flag = True
if ax is None:
fig_flag = False
fig = plt.figure(figsize = (10, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
else:
fig = None
## check if frequency or time data
if stk_type == "t":
ax.set_xlabel("Time [ms]", fontsize = 12)
xdat = "time"
elif stk_type == "f":
ax.set_xlabel("Frequency [MHz]", fontsize = 12)
xdat = "freq"
else:
print("Invalid type")
## update ax
ax.set_ylabel("Flux Density (arb.)", fontsize = 12)
st = stk_type
data_list = [f"{st}I", f"{st}Q", f"{st}U", f"{st}V", xdat]
if Ldebias:
data_list += [f"{st}Ierr"]
# get data
pdat, err_flag = _data_from_dict(dat, data_list)
# stokes to plot
col = default_col[0:4]
col = {"I":'k', "Q":col[1], "U":col[2], "V":'b', "L": 'r', "P": 'darkviolet'}
# check if L was given in plotting
if "L" in stk2plot:
if Ldebias:
pdat[f"{st}L"], pdat[f"{st}Lerr"] = calc_Ldebiased(pdat[f"{st}Q"], pdat[f"{st}U"],
pdat[f'{st}Ierr'], pdat[f'{st}Qerr'], pdat[f'{st}Uerr'])
else:
pdat[f"{st}L"], pdat[f"{st}Lerr"] = calc_L(pdat[f'{st}Q'], pdat[f'{st}U'], pdat[f'{st}Qerr'],
pdat[f'{st}Uerr'])
# check if P was given in plotting
if "P" in stk2plot:
if Ldebias:
pdat[f"{st}P"], pdat[f"{st}Perr"] = calc_Pdebiased(pdat[f"{st}Q"], pdat[f"{st}U"], pdat[f"{st}V"],
pdat[f'{st}Ierr'], pdat[f'{st}Qerr'], pdat[f'{st}Uerr'], pdat[f'{st}Verr'])
else:
pdat[f"{st}P"], pdat[f"{st}Perr"] = calc_P(pdat[f'{st}Q'], pdat[f'{st}U'], pdat[f'{st}V'],
pdat[f'{st}Qerr'], pdat[f'{st}Uerr'], pdat[f'{st}Verr'])
# plot stokes ratios
if stk_ratio:
stk2plot = stk2plot.replace("I", "") # remove part of string
# get sigma mask
sigma_mask = pdat[f"{st}I"] < sigma * pdat[f"{st}Ierr"]
# calc ratios
for S in stk2plot:
pdat[f"{st}{S}"], pdat[f"{st}{S}err"] = calc_ratio(pdat[f"{st}I"], pdat[f"{st}{S}"],
pdat[f"{st}Ierr"], pdat[f"{st}{S}err"])
# mask values with too large errors
pdat[f"{st}{S}"][sigma_mask] = np.nan
pdat[f"{st}{S}err"][sigma_mask] = np.nan
# now we are ready to plot stokes data
for i, S in enumerate(stk2plot):
_PLOT(x = pdat[xdat], y = pdat[f"{st}{S}"], yerr = pdat[f"{st}{S}err"], ax = ax, color = col[S],
plot_type = plot_type, label = S)
ax.legend()
return fig
[docs]
def create_poincare_sphere(cbar_lims, cbar_label):
"""
Create poincare sphere plot
Parameters
----------
cbar_lims : list(float)
colorbar limits
cbar_label : str
colorbar label
Returns
-------
fig : figure
figure instance
"""
fig = plt.figure(figsize = (12,12))
ax = fig.add_subplot(111, projection = '3d')
def set_axes_equal(ax: plt.Axes):
"""Set 3D plot axes to equal scale.
Make axes of 3D plot have equal scale so that spheres appear as
spheres and cubes as cubes. Required since `ax.axis('equal')`
and `ax.set_aspect('equal')` don't work on 3D.
"""
limits = np.array([
ax.get_xlim3d(),
ax.get_ylim3d(),
ax.get_zlim3d(),
])
origin = np.mean(limits, axis=1)
radius = 0.5 * np.max(np.abs(limits[:, 1] - limits[:, 0]))
_set_axes_radius(ax, origin, radius)
def _set_axes_radius(ax, origin, radius):
x, y, z = origin
ax.set_xlim3d([x - radius, x + radius])
ax.set_ylim3d([y - radius, y + radius])
ax.set_zlim3d([z - radius, z + radius])
# plot sphere surface
u = np.linspace(0, 2*np.pi, 200)
v = np.linspace(0, np.pi, 200)
u, v = np.meshgrid(u, v)
x = np.sin(u) * np.cos(v)
y = np.sin(u) * np.sin(v)
z = np.cos(u)
ax.plot_surface(x,y,z, color = [0.7, 0.7, 0.7, 0.3], shade = False)
ax.plot_wireframe(np.sin(u), np.sin(u)*0, np.cos(u), color = [0.4, 0.4, 0.4, 0.5], linestyle = '--')
ax.plot_wireframe(np.sin(u)*0, np.sin(u), np.cos(u), color = [0.4, 0.4, 0.4, 0.5], linestyle = '--')
ax.plot_wireframe(np.sin(u), np.cos(u), np.cos(u)*0, color = [0.4, 0.4, 0.4, 0.5], linestyle = '--')
# plot axes
fig.tight_layout()
ax.plot([-1.0, 1.0], [0.0, 0.0], [0.0, 0.0], color = default_col[1], linestyle = '-.')
ax.plot([0.0, 0.0], [-1.0, 1.0], [0.0, 0.0], color = default_col[2], linestyle = '-.')
ax.plot([0.0, 0.0], [0.0, 0.0], [-1.0, 1.0], color = default_col[3], linestyle = '-.')
ax.text(1.2, 0, 0, "Q", fontsize = 16, color = default_col[1])
ax.text(0, 1.2, 0, "U", fontsize = 16, color = default_col[2])
ax.text(0, 0, 1.2, "V", fontsize = 16, color = default_col[3])
ax.set_xlim([-1.2, 1.2])
ax.set_xlim([-1.2, 1.2])
ax.set_xlim([-1.2, 1.2])
ax.set_box_aspect([1,1,1])
set_axes_equal(ax)
ax.dist = 7.5
ax.set_axis_off()
# create colorbar
ax_c = fig.add_axes([0.2, 0.07, 0.6, 0.02])
ax_c.get_yaxis().set_visible(False)
ax_c.imshow(np.linspace(0,1.0, 256).reshape(1, 256)[::-1], aspect = 'auto', extent = [*cbar_lims, 0.0, 1.0],
cmap = 'viridis')
ax_c.set_xlabel(cbar_label)
return fig, ax
# split into two functions, one to plot sphere, other to plot track in 3D
[docs]
def plot_poincare_track(dat, ax, sigma = 2.0, plot_data = True, plot_model = False,
normalise = True, n = 5):
"""
Plot Stokes data on a Poincare Sphere.
Parameters
----------
dat : Dict(np.ndarray)
Dictionary of stokes data, can include any data products but must include the following: \n
[I] - Stokes I data \n
[Q] - Stokes Q data \n
[U] - Stokes U data \n
[V] - Stokes V data \n
[Ierr] - Stokes I error data
stk_type : str, optional
types of stokes data to plot, by default "f" \n
[f] - Plot as a function of frequency \n
[t] - Plot as a function of time
sigma : float, optional
Error threshold used for masking stokes data in the case that stokes/I is being calculated \n
this avoids deviding by potentially small numbers and getting weird results,by default 2.0
plot_data : bool, optional
Plot Data on Poincare sphere, by default True
plot_model : bool, optional
Plot Polynomial fitted data on Poincare sphere, by default False
normalise : bool, optional
Plot data on surface of Poincare sphere (this will require normalising stokes data), by default True
n : int, optional
Maximum order of Polynomial fit, by default 5
Returns
-------
stk_i: Stokes 1D arrays
stk_m: Stokes model 1D arrays
"""
data_list = ["I", "Q", "U", "V", "Ierr"]
# get data
pdat, err_flag = _data_from_dict(dat, data_list)
# calculate stokes ratios
# choice of normalizing against stokes I or P
P = pdat['I'].copy()
Perr = pdat['Ierr'].copy()
if not err_flag:
print("stk/P requires all stokes err")
normalise = False
if normalise:
P, Perr = calc_Pdebiased(pdat['Q'], pdat['U'], pdat['V'], pdat['Ierr'],
pdat['Qerr'], pdat['Uerr'], pdat['Verr'])
stk_mask = P >= sigma * Perr
stk_i = {}
for S in "QUV":
stk_i[S] = pdat[S].copy()
stk_i[S][stk_mask] = pdat[S][stk_mask]/P[stk_mask]
stk_i[S][~stk_mask] = np.nan
stk_o = deepcopy(stk_i)
# model stokes data
if plot_model:
stk_m = {}
stk_mo = {}
for S in "QUV":
stk_m[S] = model_curve(stk_i[S], n = n, samp = 1000)
# stk_mo[S] = model_curve(stk_o[S], n = n, samp = 1000)
# plot stokes data
if plot_data:
cols = cm.viridis(np.linspace(0, 1, stk_i['Q'].size - 1))
for i in range(stk_i['Q'].size - 1):
ax.plot(stk_i['Q'][i:i+2], stk_i['U'][i:i+2], stk_i['V'][i:i+2], color = cols[i], linewidth = 3)
# plot model
if plot_model:
if plot_data:
ax.plot(stk_m['Q'], stk_m['U'], stk_m['V'], color = 'r', linestyle = '--', linewidth = 1.5)
else:
cols = cm.viridis(np.linspace(0, 1, stk_m['Q'].size - 1))
for i in range(stk_m['Q'].size - 1):
ax.plot(stk_m['Q'][i:i+2], stk_m['U'][i:i+2], stk_m['V'][i:i+2], color = cols[i], linewidth = 3)
return stk_i, stk_m