# Source code for ilex.fitting

##===============================================##
##===============================================##
## Author: Tyson Dial
## Email: tdial@swin.edu.au
## Last Updated: 25/09/2023 
##
##
## 
## 
## Stats functions for analysing FRBs.
##
##
##===============================================##
##===============================================##
# imports
#TODO: add import guards?
import numpy as np
from scipy.optimize import curve_fit
from scipy.signal import correlate
import bilby
from copy import deepcopy
import sys, inspect
from .logging import log
from math import ceil, floor
import matplotlib.pyplot as plt


# rm synthesis
from RMtools_1D.do_RMsynth_1D import run_rmsynth
from RMtools_1D.do_RMclean_1D import run_rmclean

## import utils ##
from .utils import struct_

from .data import *

from .pyfit import fit, _clean_bilby_run, _priorUniform

from .globals import *



##===============================================##
##             utility functions                 ##
##===============================================##


def model_curve(y, n: int = 5, samp: int = None):
    """
    Fit a polynomial of order ``n`` to ``y`` and resample the fitted model.

    The data are placed on a normalised abscissa running from 0 to 1, fitted with
    ``np.polyfit`` and then evaluated on ``samp`` evenly spaced points over the same
    0 to 1 range.

    Parameters
    ----------
    y : np.ndarray
        data to model
    n : int, optional
        Polynomial order, by default 5
    samp : int, optional
        number of samples of the returned model, by default None (same length as y)

    Returns
    -------
    np.ndarray
        Modelled data
    """
    x_in = np.linspace(0, 1.0, y.size)
    n_out = x_in.size if samp is None else samp
    x_out = np.linspace(0, 1.0, n_out)

    coeffs = np.polyfit(x_in, y, n)
    return np.polyval(coeffs, x_out)
def make_polyfit(n):
    """
    Build a polynomial model function of order ``n``.

    The returned function has the signature ``f(x, x0, x1, ..., x{n-1}, c)`` and
    evaluates ``x0*x**n + x1*x**(n-1) + ... + x{n-1}*x**1 + c``. It is generated
    with ``eval`` so that every coefficient is a separately named argument. This
    lets routines such as ``scipy.optimize.curve_fit`` infer the number of free
    parameters from the signature.

    Parameters
    ----------
    n : int
        Polynomial order (number of non-constant coefficients)

    Returns
    -------
    function
        lambda evaluating the polynomial
    """
    coeff_names = [f"x{k}" for k in range(n)]
    terms = [f"{name} * x**{n - k}" for k, name in enumerate(coeff_names)]

    # the constant is appended as "+ c" (a unary plus), giving e.g.
    # "x0 * x**2 + x1 * x**1 + + c", which is still valid Python
    signature = ", ".join(["x", *coeff_names, "c"])
    body = " + ".join(terms + ["+ c"])
    return eval("lambda " + signature + " : " + body)
##===============================================## ## basic fitting functions ## ##===============================================## ## [ LORENTZIAN FUNCTION ] ##
def lorentz(x,w,a):
    """
    Lorentzian function, typically used to model the scintillation bandwidth.

    Evaluates ``a * w**2 / (w**2 + x**2)``, which peaks at ``a`` when ``x = 0``
    and falls to ``a/2`` at ``x = +/- w``.

    Parameters
    ----------
    x : np.ndarray
        X data - Usually Frequency array
    w : float
        width - Usually Scintillation Bandwidth
    a : float
        Amplitude - Usually m^2 where m is modulation index

    Returns
    -------
    np.ndarray
        Y data
    """
    w_sq = w**2
    return a * w_sq / (w_sq + x**2)
## [ GAUSSIAN FUNCTION ] ##
def gaussian(x, a, mu, sig):
    """
    Gaussian pulse function ``a * exp(-(x - mu)**2 / (2 * sig**2))``.

    Parameters
    ----------
    x : np.ndarray
        X data
    a : float
        amplitude (value at x = mu)
    mu : float
        position of Gaussian Pulse
    sig : float
        width (standard deviation) of Gaussian Pulse

    Returns
    -------
    np.ndarray
        Y data
    """
    exponent = -((x - mu) ** 2) / (2 * sig ** 2)
    return a * np.exp(exponent)
def scat(dx, tau, sig = 10):
    """
    One-sided (positive side) exponential scattering tail kernel.

    The kernel is sampled at spacing ``dx`` on a grid symmetric about zero, so it
    can be used with a centred ("same" mode) convolution without shifting the
    signal. The negative half is zero and the non-negative half is ``exp(-t/tau)``,
    so the centre sample equals 1.

    Parameters
    ----------
    dx : float
        sample spacing (time resolution) of the kernel
    tau : float
        Scattering Timescale. ``tau == 0`` means no scattering and returns the
        identity kernel ``[1.0]`` (previously this produced ``0/0 -> [nan]``).
    sig : float
        extent of the tail in units of ``tau``: the kernel covers
        ``ceil(tau*sig/dx)`` samples on each side of zero, by default 10

    Returns
    -------
    np.ndarray
        Y data, of odd length ``2*ceil(tau*sig/dx) + 1``
    """
    if tau == 0:
        # no scattering: a delta function is the identity kernel for convolution
        return np.ones(1)

    # create x with same time resolution, symmetric about zero
    _w = int(ceil(tau*sig/dx))
    x = np.linspace(-_w*dx, _w*dx, 2*_w+1)
    hw = x.size//2

    # only fill one side, since this is a one sided exponential
    yscat = np.zeros(x.size)
    yscat[hw:] = np.exp(-x[hw:]/tau)
    return yscat
def specindex(x, a, alpha):
    """
    Spectral index power-law function ``a * x**alpha``.

    Parameters
    ----------
    x : np.ndarray
        X data
    a : float
        Amplitude
    alpha : float
        Power-law index

    Returns
    -------
    np.ndarray
        Y data
    """
    powered = x ** alpha
    return a * powered
##===============================================## ## Advanced fitting functions ## ##===============================================##
def scatt_pulse_profile(x, p):
    """
    Scattering time series profile with n pulses.

    Numerical convolution, if done incorrectly, can shift the resulting data in an
    undesirable way. One way to avoid this is to take a large window around the known
    signal that encompasses all pulses and convolve it with a symmetrical scattering
    tail. That isn't realistic when the data is a crop whose bounds cut through
    potential signal. To keep this function robust, each Gaussian pulse is evaluated
    on its own window, which is extended to be at least as long as the (symmetric)
    scattering kernel. The pulse is convolved separately and the in-range part is
    added to the output, so no shift is introduced.

    Parameters
    ----------
    x: np.ndarray
        X data array, assumed uniformly sampled (spacing taken as x[1] - x[0])
    p: Dict(float)
        dictionary of parameters for scattered Gaussian pulses, for each pulse n = 1, 2, ...: \n
        [a[n]] - Pulse amplitude (peak value of the scattered pulse) \n
        [mu[n]] - Pulse position \n
        [sig[n]] - Pulse width \n
        [tau] - scattering timescale (shared by all pulses)

    Returns
    -------
    y: np.ndarray
        Y data array
    """
    # create empty output array
    y = np.zeros(x.size)

    x0 = x[0]
    dt = x[1] - x[0]
    npulses = (len(p) - 1)//3  # 3 parameters per pulse, plus tau
    stail = scat(dt, p['tau'])

    # Each gaussian is isolated and convolved separately, with enough padding for a
    # complete uniform convolution with zero shifting due to numerical error.
    for i in range(1, npulses + 1):
        a, mu, sig = p[f"a{i}"], p[f"mu{i}"], p[f"sig{i}"]

        # sample indices, RELATIVE TO x[0], spanning mu +/- 5 sigma. Working relative
        # to x[0] keeps the pulse window on exactly the same grid as x. Indexing an
        # absolute k*dt grid (as before) misaligns the pulse when x[0] is not an
        # integer multiple of dt, or when x[0]/dt rounds down (e.g. 0.3/0.1).
        xs = int(floor((mu - sig*5 - x0)/dt))
        xe = int(floor((mu + sig*5 - x0)/dt))

        # make sure the window is at least as long as the scattering tail so that
        # "same"-mode convolution returns the window length, centred without shift
        npts = xe - xs + 1
        if npts < stail.size:
            lendif = int(ceil((stail.size - npts)/2))
            xe += lendif
            xs -= lendif

        # skip pulses whose window lies entirely outside x
        # (xe == 0 still overlaps index 0, so only xe < 0 is skipped)
        if xs >= x.size or xe < 0:
            continue

        # make pulse on the same grid as x and convolve
        x_i = x0 + dt*np.arange(xs, xe + 1)
        pulse_i = gaussian(x_i, 1, mu, sig)
        conv = np.convolve(pulse_i, stail, mode = "same")

        peak = np.max(conv)
        if peak <= 0:
            # pulse underflowed to zero on this grid; nothing to add (avoids 0/0)
            continue

        # crop the window to the part that overlaps x
        lo = max(0, -xs)
        hi = conv.size - max(0, xe + 1 - x.size)
        y[max(xs, 0):min(xe + 1, x.size)] += a * conv[lo:hi]/peak

    return y
# def scatt_pulse_profile(x, p): # """ # Scattering time series profile with n pulses. Numerical convolution if done incorrectly # can shift the resultant data in an undesirable way. One way to avoid this is to take a large window # around the known signal to encompass the all pulses and convolve this with a symmetrical scattering tail. # This of course isn't realistic when taking a crop of data whose bounds cut through potential signal. # To keep this function robust, the algorithm implemented here takes each gaussian profile and extends it until # symmetrical, this avoids any potential shifting due to improper convolution. # Parameters # ---------- # x: np.ndarray # X data array # p: Dict(float) # dictionary of parameters for scattered Gaussian pulses, for each pulse n: \n # [a[n]] - Pulse amplitude \n # [mu[n]] - Pulse position \n # [sig[n]] - Pulse width \n # [tau] - scattering timescale # Returns # ------- # y: np.ndarray # Y data array # """ # # create empty output array # y = np.zeros(x.size) # print(x) # # create scattering tail # dt = x[1] - x[0] # npulses = (len(p) - 1)//3 # stail = scat(dt,p['tau'], sig = 3) # print(stail.size) # # plt.plot(np.linspace(0.0, 1.0, stail.size), stail) # # Each gaussian will be isolated and convolved seperatley with enough padding for a complete # # uniform convolution with zero shifting due to numerical error. 
# for i in range(npulses): # # make gaussian with sigma 5 # xe = int(floor((p[f"mu{i+1}"] + p[f"sig{i+1}"]*5)/dt)) #assuming starts at zero, at increments of # xs = int(floor((p[f"mu{i+1}"] - p[f"sig{i+1}"]*5)/dt)) # # make sure the scattering tail is smaller or equal to the size of the gaussian to convolve # # if x.size < stail.size: # # # expand to same size as stail # # lendif = (stail.size - (x.size)) # # xe += lendif # # xs -= lendif # x_i = np.linspace(xs*dt, xe*dt, xe-xs + 1) # print(x_i) # # crop bounded signal # ps = int(floor(x[0]/dt)) # xs -= ps # xe -= ps # # handle edge cases # if xs >= x.size: # continue # if xe <= 0: # continue # # make pulse # pulse_i = gaussian(x, 1, p[f"mu{i+1}"], p[f"sig{i+1}"]) # print(pulse_i.size) # print(x.size) # print(stail.size) # # print(pulse_i.size) # # plt.plot(x_i, pulse_i) # # convolve # conv = np.convolve(pulse_i, stail, mode = "same") # # plt.plot(np.linspace(0,1.0, stail.size),stail) # plt.plot(x, conv/np.max(conv)*p[f"a{i+1}"], label = f"{1/dt}") # pulse_ind = [0, conv.size] # if xs < 0: # pulse_ind[0] = 0 - xs # xs = 0 # if xe + 1 > x.size: # pulse_ind[1] -= (xe+1 - x.size) # xe = x.size # # y[xs:xe+1] += p[f"a{i+1}"] * conv[pulse_ind[0]:pulse_ind[1]]/np.max(conv) # return y
def make_scatt_pulse_profile_func(n = 1):
    """
    Make a scatter pulse profile wrapping function for fitting.

    The returned lambda has the signature
    ``f(x, a1, mu1, sig1, ..., an, mun, sign, tau)`` and forwards its arguments
    to ``scatt_pulse_profile`` as a parameter dictionary. It is built with ``eval``
    so that each parameter is a separately named argument.

    Parameters
    ----------
    n: int
        number of pulses in scatter profile

    Returns
    -------
    func: __func__
        lambda function for scatter pulse profile with n pulses
    """
    arg_names = ["x"]
    dict_entries = []
    for comp in range(1, n + 1):
        for prefix in ("a", "mu", "sig"):
            name = f"{prefix}{comp}"
            arg_names.append(name)
            dict_entries.append(f"'{name}':{name},")

    # shared scattering timescale goes last
    arg_names.append("tau")
    dict_entries.append("'tau':tau")

    source = ("lambda " + ",".join(arg_names)
              + ":scatt_pulse_profile(x,{" + "".join(dict_entries) + "})")
    return eval(source)
##============================## ## fitting ## ##============================## ## RM fitting functions ##
def fit_RMsynth(I, Q, U, Ierr, Qerr, Uerr, f, clean_cutoff = 0.1, **kwargs):
    """
    Use RM synthesis (followed by RM clean) to calculate RM, pa0 and f0.

    f0 is the weighted mid-band frequency reported by ``run_rmsynth`` and pa0 is
    the position angle at f0.

    Parameters
    ----------
    I: np.ndarray
        stokes I spectra
    Q: np.ndarray
        stokes Q spectra
    U: np.ndarray
        stokes U spectra
    Ierr: np.ndarray
        stokes I rms spectra
    Qerr: np.ndarray
        stokes Q rms spectra
    Uerr: np.ndarray
        stokes U rms spectra
    f: np.ndarray
        frequencies [MHz]
    clean_cutoff: float
        cutoff arg for run_rmclean()
    **kwargs: Dict
        keyword arguments for RM tools run_rmsynth. Missing keys default to
        polyOrd=3, phiMax_radm2=1.0e3, dPhi_radm2=1.0, nSamples=100.0

    Returns
    -------
    rm: float
        rotation measure
    rm_err: float
        error in rotation measure
    f0: float
        reference frequency at weighted mid-band [MHz]
    pa0: float
        position angle at f0 [rad]
    """
    defaults = {"polyOrd":3, "phiMax_radm2":1.0e3, "dPhi_radm2":1.0, "nSamples":100.0}
    for name, value in defaults.items():
        kwargs.setdefault(name, value)

    log("Fitting RM using RM synthesis", lpf = False)

    # rows: freqs (Hz), I, Q, U, I rms, Q rms, U rms
    rmsyn_data = np.array([f * 1e6, I, Q, U, Ierr, Qerr, Uerr])

    # run RM synthesis, then RM clean
    synth_summary, synth_arrays = run_rmsynth(rmsyn_data, **kwargs)
    clean_result = run_rmclean(synth_summary, synth_arrays, clean_cutoff)
    clean_summary = clean_result[0]

    # get estimated parameters
    rm = clean_summary['phiPeakPIfit_rm2']
    rm_err = clean_summary['dPhiPeakPIfit_rm2']
    f0 = synth_summary['freq0_Hz'] / 1e6
    pa0 = 0.5 * np.arctan2(clean_summary['peakFDFimagFit'], clean_summary['peakFDFrealFit'])

    log(f"RM: {rm:.4f} +/- {rm_err:.4f} (rad/m2)", lpf = False)
    log(f"f0: {f0} (MHz)", lpf = False)
    log(f"pa0: {pa0} (rad)", lpf = False)

    return rm, rm_err, f0, pa0
def fit_RMquad(Q, U, Qerr, Uerr, f, f0, **kwargs):
    """
    Use the quadratic method to fit for RM and pa0.

    The measured position angle (unwrapped with period pi) is fitted with
    ``pa0 + RM * c**2 * (1/f**2 - 1/f0**2)`` (frequencies converted from MHz to Hz)
    using ``scipy.optimize.curve_fit`` with absolute sigma.

    Parameters
    ----------
    Q: np.ndarray
        stokes Q spectra
    U: np.ndarray
        stokes U spectra
    Qerr: np.ndarray
        stokes Q rms spectra
    Uerr: np.ndarray
        stokes U rms spectra
    f: np.ndarray
        frequencies [MHz]
    f0: float
        reference Frequency [MHz]
    **kwargs: Dict
        keyword arguments passed to ``scipy.optimize.curve_fit``. ``maxfev``
        defaults to 2000000 but may now be overridden (previously passing it
        raised a duplicate-keyword TypeError).

    Returns
    -------
    rm: float
        rotation measure
    rm_err: float
        error in rotation measure
    pa0: float
        position angle at f0
    pa0_err: float
        position angle err
    """
    log("Fitting using RM quadratic function", lpf = False)

    # fit RM using Quadratic function
    def rmquad(f, rm, pa0):
        # f in MHz: c**2 / (f*1e6)**2 == c**2 / 1e12 / f**2
        # (c comes from .globals; presumably speed of light in m/s)
        return pa0 + rm*c**2/1e12*(1/f**2 - 1/f0**2)

    PA_meas, PA_err = calc_PA(Q, U, Qerr, Uerr)
    PA_meas = np.unwrap(PA_meas, period = np.pi)

    # caller-supplied kwargs take precedence over the default maxfev
    fit_kwargs = {"maxfev": 2000000, **kwargs}
    fit_val, fit_cov = curve_fit(rmquad, f, PA_meas, sigma = PA_err,
                                 absolute_sigma = True, **fit_kwargs)
    fit_err = np.sqrt(np.diag(fit_cov))

    # get params
    rm, pa0 = fit_val
    rm_err, pa0_err = fit_err

    log(f"RM: {rm:.4f} +/- {rm_err:.4f} (rad/m2)", lpf = False)
    log(f"f0: {f0} (MHz)", lpf = False)
    log(f"pa0: {pa0} +/- {pa0_err:.4f} (rad)", lpf = False)

    return rm, rm_err, pa0, pa0_err
# QUfitting likelihood class
class QUfit_likelihood(bilby.Likelihood):
    """
    Bilby likelihood for fitting the rotation measure (RM) and reference position
    angle (pa0) directly to Stokes Q and U spectra.
    """

    def __init__(self, f, Q, U, Ierr, Qerr, Uerr):
        """
        Likelihood function for evaluating Stokes Q and U parameters to fit RM.

        L will be debiased, and a mask based on debiased L will be applied to all
        Stokes parameters. Channels with L_meas / Ierr < 1.57 are dropped.

        Parameters
        ----------
        f : np.ndarray
            Frequencies in MHz
        Q : np.ndarray
            Stokes Q parameter
        U : np.ndarray
            Stokes U parameter
        Ierr : np.ndarray
            Stokes I parameter noise, used to calculate L debiased
        Qerr : np.ndarray
            Stokes Q parameter noise
        Uerr : np.ndarray
            Stokes U parameter noise
        """
        # constants
        self.c = 299_792_458  # speed of light [m/s]

        # calculate debiased L = Ierr * sqrt((L_meas/Ierr)**2 - 1)
        L_meas = np.sqrt(Q**2 + U**2)
        L_snr = L_meas / Ierr
        # channels with L_snr < 1 give sqrt(negative) = nan; they are masked out
        # below anyway (threshold 1.57), so silence numpy's "invalid value" warning
        with np.errstate(invalid = "ignore"):
            self.L = Ierr * np.sqrt(L_snr**2 - 1)
        self.L[L_snr < 1.57] = np.nan
        self.mask = ~np.isnan(self.L)
        self.L = self.L[self.mask]

        # f data
        self.f = f[self.mask] * 1e6  # in Hz
        self.N = f.size  # number of channels BEFORE masking

        # stk data
        self.Q = Q[self.mask]
        self.U = U[self.mask]

        # stk noise
        self.Qerr = Qerr[self.mask]
        self.Uerr = Uerr[self.mask]

        # infer the sampled parameter names from the PA model signature;
        # getfullargspec on a bound method still lists "self", hence [1:]
        parameters = inspect.getfullargspec(self.PA).args[1:]
        super().__init__(parameters = dict.fromkeys(parameters))
        self.parameters = dict.fromkeys(parameters)
        self.function_keys = ["RM", "pa0"]

    @property
    def model_parameters(self):
        """Current RM and pa0 values, taken from ``self.parameters``."""
        return {k: self.parameters[k] for k in self.function_keys}

    def PA(self, RM, pa0):
        """
        Model polarisation position angle ``pa0 + RM * lambda**2``.

        ``lambda**2 = c**2 / f**2`` with f the masked frequencies in Hz.

        Parameters
        ----------
        RM : float
            Rotation Measure [rad/m^2]
        pa0 : float
            Initial reference polarisation position angle [rad]

        Returns
        -------
        np.ndarray
            Position angle per (unmasked) channel [rad]
        """
        return pa0 + RM * self.c**2 / self.f**2

    def log_likelihood(self):
        """
        Gaussian log likelihood of the masked Stokes Q and U data.

        The model is ``Q = L cos(2 PA)`` and ``U = L sin(2 PA)``, using debiased L.
        The Q and U log likelihoods are summed.
        """
        PA = self.PA(**self.model_parameters)

        # calculate Stokes Q log likelihood
        ll_Q = -0.5 * np.sum(np.log(2 * np.pi * self.Qerr**2)
                             + ((self.Q - self.L * np.cos(2*PA))/self.Qerr)**2)

        # calculate Stokes U log likelihood
        ll_U = -0.5 * np.sum(np.log(2 * np.pi * self.Uerr**2)
                             + ((self.U - self.L * np.sin(2*PA))/self.Uerr)**2)

        return ll_Q + ll_U
def RM_QUfit(Q, U, Ierr, Qerr, Uerr, f, rm_priors = (-1000, 1000), pa0_priors = (-3.1415926, 0), **kwargs):
    """
    Fit RM and pa0 using the QU-fitting method (bilby sampling of QUfit_likelihood).

    Parameters
    ----------
    Q : np.ndarray
        Stokes Q parameter
    U : np.ndarray
        Stokes U parameter
    Ierr : np.ndarray
        Stokes I parameter noise, used to calculate L debiased
    Qerr : np.ndarray
        Stokes Q parameter noise
    Uerr : np.ndarray
        Stokes U parameter noise
    f : np.ndarray
        Frequencies in MHz
    rm_priors : sequence of 2 floats, optional
        [min, max] prior bounds for RM [rad/m^2], converted with ``_priorUniform``.
        Any sequence (list, tuple, array) is accepted. By default (-1000, 1000).
    pa0_priors : sequence of 2 floats, optional
        [min, max] prior bounds for pa0 [rad], converted with ``_priorUniform``.
        By default (-3.1415926, 0).
    **kwargs : Dict
        keyword arguments passed to ``bilby.run_sampler``. ``outdir`` (default
        "outdir") and ``label`` (default "label") are also used to clean a
        previous run before sampling.

    Returns
    -------
    rm : float
        posterior median of RM
    rm_err : float
        mean of the absolute upper and lower RM error bars
    pa0 : float
        posterior median of pa0
    pa0_err : float
        mean of the absolute upper and lower pa0 error bars
    """
    outdir = kwargs.get("outdir", "outdir")
    label = kwargs.get("label", "label")
    _clean_bilby_run(outdir, label)

    # list() copies the bounds (caller data is never mutated) and, unlike
    # .copy(), also accepts tuples
    priors = {'RM': list(rm_priors), 'pa0': list(pa0_priors)}

    # start sampling
    likelihood = QUfit_likelihood(f = f, Q = Q, U = U, Ierr = Ierr, Qerr = Qerr, Uerr = Uerr)
    result = bilby.run_sampler(likelihood = likelihood, priors = _priorUniform(priors), **kwargs)

    # get rm and pa0 measurements
    rm_posterior = result.get_one_dimensional_median_and_error_bar('RM')
    rm = rm_posterior.median
    rm_err = (abs(rm_posterior.plus) + abs(rm_posterior.minus))/2

    pa0_posterior = result.get_one_dimensional_median_and_error_bar('pa0')
    pa0 = pa0_posterior.median
    pa0_err = (abs(pa0_posterior.plus) + abs(pa0_posterior.minus))/2

    # plots
    result.plot_corner()

    return rm, rm_err, pa0, pa0_err