__all__ = ("TransientMetric",)
import numpy as np
from .base_metric import BaseMetric
class TransientMetric(BaseMetric):
"""
Calculate what fraction of the transients would be detected. Best paired with a spatial slicer.
We are assuming simple light curves with no color evolution.
Parameters
----------
trans_duration : float, optional
How long the transient lasts (days). Default 10.
peak_time : float, optional
How long it takes to reach the peak magnitude (days). Default 5.
rise_slope : float, optional
Slope of the light curve before peak time (mags/day).
This should be negative since mags are backwards (magnitudes decrease towards brighter fluxes).
Default 0.
decline_slope : float, optional
Slope of the light curve after peak time (mags/day).
This should be positive since mags are backwards. Default 0.
u_peak : float, optional
Peak magnitude in u band. Default 20.
g_peak : float, optional
Peak magnitude in g band. Default 20.
r_peak : float, optional
Peak magnitude in r band. Default 20.
i_peak : float, optional
Peak magnitude in i band. Default 20.
z_peak : float, optional
Peak magnitude in z band. Default 20.
y_peak : float, optional
Peak magnitude in y band. Default 20.
survey_duration : float, optional
Length of survey (years).
Default 10.
survey_start : float, optional
MJD for the survey start date.
Default None (uses the time of the first observation).
detect_m5_plus : float, optional
An observation will be used if the light curve magnitude is brighter than m5+detect_m5_plus.
Default 0.
n_pre_peak : int, optional
Number of observations (in any filter(s)) to demand before peak_time,
before saying a transient has been detected.
Default 0.
n_per_lc : int, optional
Number of sections of the light curve that must be sampled above the detect_m5_plus threshold
(in a single filter) for the light curve to be counted.
For example, setting n_per_lc = 2 means a light curve is only considered detected if there
is at least one observation in the first half of the LC, and at least one in the second half.
n_per_lc = 4 means each quarter of the light curve must be detected to count.
Default 1.
n_filters : int, optional
Number of filters that need to be observed for an object to be counted as detected.
Default 1.
n_phase_check : int, optional
Sets the number of phases that should be checked.
One can imagine pathological cadences where many objects pass the detection criteria,
but would not if the observations were offset by a phase-shift.
Default 1.
count_method : {'full', 'partialLC'}, optional
Sets the method of counting the maximum number of transients. If 'full',
only light curves that fit entirely within the survey duration are counted
(the integer floor of the survey duration divided by the transient duration).
If 'partialLC', transients whose light curves are truncated at the survey
edges are also counted (the integer ceiling). Default 'full'.
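Examples
--------
A minimal usage sketch, assuming rubin_sim.maf is importable and using the
default opsim column names. Normally this metric would be run through a MAF
slicer and metric bundle rather than called directly on a visit array.

>>> import numpy as np
>>> visits = np.zeros(50, dtype=[("observationStartMJD", float),
...                              ("fiveSigmaDepth", float),
...                              ("filter", "U1")])
>>> visits["observationStartMJD"] = 59853.0 + np.linspace(0, 3650, 50)
>>> visits["fiveSigmaDepth"] = 24.5
>>> visits["filter"] = "r"
>>> metric = TransientMetric(trans_duration=60.0, r_peak=22.0)
>>> fraction = metric.run(visits)  # fraction of transients detected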
"""
def __init__(
self,
metric_name="TransientDetectMetric",
mjd_col="observationStartMJD",
m5_col="fiveSigmaDepth",
filter_col="filter",
trans_duration=10.0,
peak_time=5.0,
rise_slope=0.0,
decline_slope=0.0,
survey_duration=10.0,
survey_start=None,
detect_m5_plus=0.0,
u_peak=20,
g_peak=20,
r_peak=20,
i_peak=20,
z_peak=20,
y_peak=20,
n_pre_peak=0,
n_per_lc=1,
n_filters=1,
n_phase_check=1,
count_method="full",
**kwargs,
):
self.mjd_col = mjd_col
self.m5_col = m5_col
self.filter_col = filter_col
super().__init__(
col=[self.mjd_col, self.m5_col, self.filter_col],
units="Fraction Detected",
metric_name=metric_name,
**kwargs,
)
self.peaks = {
"u": u_peak,
"g": g_peak,
"r": r_peak,
"i": i_peak,
"z": z_peak,
"y": y_peak,
}
self.trans_duration = trans_duration
self.peak_time = peak_time
self.rise_slope = rise_slope
self.decline_slope = decline_slope
self.survey_duration = survey_duration
self.survey_start = survey_start
self.detect_m5_plus = detect_m5_plus
self.n_pre_peak = n_pre_peak
self.n_per_lc = n_per_lc
self.n_filters = n_filters
self.n_phase_check = n_phase_check
self.count_method = count_method
def light_curve(self, time, filters):
"""
Calculate the magnitude of the object at each time, in each filter.
Parameters
----------
time : numpy.ndarray
The times of the observations.
filters : numpy.ndarray
The filters of the observations.
Returns
-------
numpy.ndarray
The magnitudes of the object at each time, in each filter.
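Examples
--------
A small worked example; the output follows directly from the slopes and the
r-band peak (rise: -1.0 * (t - 5), decline: 0.5 * (t - 5), plus r_peak).

>>> import numpy as np
>>> metric = TransientMetric(peak_time=5.0, rise_slope=-1.0,
...                          decline_slope=0.5, r_peak=22.0)
>>> metric.light_curve(np.array([0.0, 5.0, 15.0]), np.array(["r", "r", "r"]))
array([27., 22., 27.])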
"""
lc_mags = np.zeros(time.size, dtype=float)
rise = np.where(time <= self.peak_time)
lc_mags[rise] += self.rise_slope * (time[rise] - self.peak_time)
decline = np.where(time > self.peak_time)
lc_mags[decline] += self.decline_slope * (time[decline] - self.peak_time)
for key in self.peaks:
f_match = np.where(filters == key)
lc_mags[f_match] += self.peaks[key]
return lc_mags
def run(self, data_slice, slice_point=None):
""" "
Calculate the detectability of a transient with the specified lightcurve.
Parameters
----------
data_slice : numpy.ndarray
Numpy structured array containing the data related to the visits provided by the slicer.
slice_point : dict, optional
Dictionary containing information about the slice_point currently active in the slicer.
Returns
-------
float
The fraction of transients that would be detected.
"""
# Total number of transients that could go off back-to-back
if self.count_method == "partialLC":
_n_trans_max = np.ceil(self.survey_duration / (self.trans_duration / 365.25))
else:
_n_trans_max = np.floor(self.survey_duration / (self.trans_duration / 365.25))
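# Worked example: a 10-year survey with a 10-day transient gives
# 10 / (10 / 365.25) = 365.25, so 365 full light curves ('full')
# or 366 counting edge-truncated ones ('partialLC').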
tshifts = np.arange(self.n_phase_check) * self.trans_duration / float(self.n_phase_check)
n_detected = 0
n_trans_max = 0
for tshift in tshifts:
# Compute the total number of back-to-back transients that could be
# detected, given the survey duration and the transient duration.
n_trans_max += _n_trans_max
if tshift != 0:
n_trans_max -= 1
if self.survey_start is None:
survey_start = data_slice[self.mjd_col].min()
else:
survey_start = self.survey_start
time = (data_slice[self.mjd_col] - survey_start + tshift) % self.trans_duration
# Which lightcurve does each point belong to
lc_number = np.floor((data_slice[self.mjd_col] - survey_start) / self.trans_duration)
lc_mags = self.light_curve(time, data_slice[self.filter_col])
# How many criteria need to be passed
detect_thresh = 0
# Flag points that are above the SNR limit
detected = np.zeros(data_slice.size, dtype=int)
detected[np.where(lc_mags < data_slice[self.m5_col] + self.detect_m5_plus)] += 1
detect_thresh += 1
# If we demand points on the rise
if self.n_pre_peak > 0:
detect_thresh += 1
order = np.argsort(data_slice[self.mjd_col])
data_slice = data_slice[order]
detected = detected[order]
lc_number = lc_number[order]
time = time[order]
ulc_number = np.unique(lc_number)
left = np.searchsorted(lc_number, ulc_number)
right = np.searchsorted(lc_number, ulc_number, side="right")
# Note: np.searchsorted is used here to do a 'group by' on light curve number.
# It might be clearer to use scipy.ndimage.find_objects or pandas, but this
# numpy function is known for being efficient.
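# Worked example: sorted lc_number = [0, 0, 1, 3, 3] gives left = [0, 2, 3]
# and right = [2, 3, 5], so detected[left[i]:right[i]] selects the points
# belonging to the i-th light curve.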
for le, ri in zip(left, right):
# Number of detected points before the peak
good = np.where(time[le:ri] < self.peak_time)
nd = np.sum(detected[le:ri][good])
if nd >= self.n_pre_peak:
detected[le:ri] += 1
# Check if we need multiple points per light curve or multiple filters
if (self.n_per_lc > 1) | (self.n_filters > 1):
# make sure things are sorted by time
order = np.argsort(data_slice[self.mjd_col])
data_slice = data_slice[order]
detected = detected[order]
lc_number = lc_number[order]
time = time[order]
ulc_number = np.unique(lc_number)
left = np.searchsorted(lc_number, ulc_number)
right = np.searchsorted(lc_number, ulc_number, side="right")
detect_thresh += self.n_filters
for le, ri in zip(left, right):
points = np.where(detected[le:ri] > 0)
ufilters = np.unique(data_slice[self.filter_col][le:ri][points])
phase_sections = np.floor(time[le:ri][points] / self.trans_duration * self.n_per_lc)
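# e.g. with trans_duration=10 and n_per_lc=2, a point at phase 3.0 falls
# in section 0 and a point at phase 7.5 in section 1.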
for filt_name in ufilters:
good = np.where(data_slice[self.filter_col][le:ri][points] == filt_name)
if np.size(np.unique(phase_sections[good])) >= self.n_per_lc:
detected[le:ri] += 1
# Find the unique number of light curves that passed the required number of conditions
n_detected += np.size(np.unique(lc_number[np.where(detected >= detect_thresh)]))
# Rather than keeping a single "detected" variable, maybe make a mask for each criteria, then
# reduce functions like: reduce_singleDetect, reduce_NDetect, reduce_PerLC, reduce_perFilter.
# The way I'm running now it would speed things up.
return float(n_detected) / n_trans_max