# Source code for rubin_sim.maf.maf_contrib.presto_color_kne_pop_metric

__all__ = ("PrestoColorKNePopMetric", "generate_presto_pop_slicer")

import os
import pickle
import warnings
from itertools import combinations

import numpy as np
import pandas as pd
from rubin_scheduler.data import get_data_dir
from rubin_scheduler.utils import SURVEY_START_MJD, uniform_sphere

import rubin_sim.maf.metrics as metrics
import rubin_sim.maf.slicers as slicers
from rubin_sim.phot_utils import DustValues

from .kne_metrics import KnLc


def radec2gal(ra, dec):
    """Convert equatorial coordinates to galactic coordinates.

    Parameters
    ----------
    ra : `float` or `np.ndarray`
        Right ascension, in degrees.
    dec : `float` or `np.ndarray`
        Declination, in degrees.

    Returns
    -------
    gal_l, gal_b : `float` or `np.ndarray`
        Galactic longitude and latitude, in degrees.
    """
    # astropy is imported lazily so it is only needed when this is called.
    from astropy import units as u
    from astropy.coordinates import SkyCoord

    galactic = SkyCoord(ra=ra, dec=dec, unit=(u.degree, u.degree)).galactic
    return galactic.l.degree, galactic.b.degree


def _load_hash(
    file_galactic="TotalCubeNorm_1000Obj.pkl",
    file_extragalactic="TotalCubeNorm_1000Obj.pkl",
    skyregion="extragalactic",
):
    """Load (and cache) the Presto-Color info dictionary and hash table.

    The loaded objects are stored as attributes on this function rather
    than on a metric instance, which allows easy reuse between calls.
    Only one (sky region, file) combination is cached at a time, so
    alternating between sky regions will reload the file on every switch.

    Parameters
    ----------
    file_galactic : `str`
        Name of the file (inside the ``maf`` folder of the data directory)
        containing the galactic Presto-Color phase space information.
    file_extragalactic : `str`
        Name of the file (inside the ``maf`` folder of the data directory)
        containing the extragalactic Presto-Color phase space information.
    skyregion : `str`
        The sky region of interest.
        Only two options: 'galactic' and 'extragalactic'.

    Returns
    -------
    info_dict, hash_table
        The two objects stored, in that order, in the pickle file.

    Raises
    ------
    ValueError
        If ``skyregion`` is not 'galactic' or 'extragalactic'.
    """
    if skyregion == "galactic":
        file_name = file_galactic
    elif skyregion == "extragalactic":
        file_name = file_extragalactic
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on the file path.
        raise ValueError(
            "skyregion must be 'galactic' or 'extragalactic', got %r" % (skyregion,)
        )

    # Reuse the cached data only when it was loaded for the same sky region
    # AND the same file; otherwise a different file name would silently
    # return stale data.
    if (
        hasattr(_load_hash, "InfoDict")
        and hasattr(_load_hash, "HashTable")
        and getattr(_load_hash, "skyregion", None) == skyregion
        and getattr(_load_hash, "file_name", None) == file_name
    ):
        return _load_hash.InfoDict, _load_hash.HashTable

    file_path = os.path.join(get_data_dir(), "maf", file_name)

    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at trusted data files.
    with open(file_path, "rb") as f:
        info_dict = pickle.load(f)
        hash_table = pickle.load(f)

    # Update the cache only after both objects were read successfully, so a
    # failed read cannot leave a half-updated cache behind.
    _load_hash.InfoDict = info_dict
    _load_hash.HashTable = hash_table
    _load_hash.skyregion = skyregion
    _load_hash.file_name = file_name
    return _load_hash.InfoDict, _load_hash.HashTable


def generate_presto_pop_slicer(
    skyregion="galactic",
    t_start=1,
    t_end=3652,
    n_events=10000,
    seed=42,
    n_files=100,
    d_min=10,
    d_max=300,
    gb_cut=20,
):
    """Generate a population of KNe events, and put the info about them
    into a UserPointSlicer object

    Parameters
    ----------
    skyregion : `str`
        The skyregion of interest.
        Only two options: 'galactic' and 'extragalactic'.
        Any other value triggers a warning and keeps the whole sky.
    t_start : `float`
        The night to start kilonova events on (days)
    t_end : `float`
        The final night of kilonova events
    n_events : `int`
        The number of sky positions to draw before the sky-region cut;
        the number of events kept can be smaller.
    seed : `float`
        The seed passed to `uniform_sphere`. The peak times, light curve
        indices and distances are drawn from the global ``np.random``
        state, which is not seeded explicitly in this function.
    n_files : `int`
        The number of different kilonova lightcurves to use
    d_min : `float` or `int`
        Minimum luminosity distance (Mpc)
    d_max : `float` or `int`
        Maximum luminosity distance (Mpc)
    gb_cut : `float` or `int`
        Absolute galactic latitude (degrees) separating the 'galactic'
        (``|b| < gb_cut``) and 'extragalactic' (``|b| > gb_cut``) regions.

    Returns
    -------
    kne_slicer : `~.maf.UserPointsSlicer`
    """

    def _power_law_draw(lower, upper, index, size=1):
        """Power-law gen for pdf(x) proportional to x^{index-1},
        for lower <= x <= upper."""
        uniform_draw = np.random.random(size=size)
        lower_pow = lower**index
        upper_pow = upper**index
        return (lower_pow + (upper_pow - lower_pow) * uniform_draw) ** (1.0 / index)

    ra, dec = uniform_sphere(n_events, seed=seed)

    # Only the galactic latitude is needed to decide which events to keep.
    _, gal_b = radec2gal(ra, dec)
    abs_gal_b = np.abs(gal_b)

    if skyregion == "galactic":
        # keep the events close to the Galactic plane
        keep = abs_gal_b < gb_cut
        ra, dec = ra[keep], dec[keep]
    elif skyregion == "extragalactic":
        # keep the events away from the Galactic plane
        keep = abs_gal_b > gb_cut
        ra, dec = ra[keep], dec[keep]
    else:
        warnings.warn("Skyregion %s not recognized, using whole sky" % skyregion)

    # From here on, n_events is the number of events that survived the cut.
    n_events = len(ra)

    # The order of the random draws below is kept as-is on purpose, so the
    # generated population does not change for a given random state.
    peak_times = np.random.uniform(low=t_start, high=t_end, size=n_events)
    file_indx = np.floor(np.random.uniform(low=0, high=n_files, size=n_events)).astype(int)
    distance = _power_law_draw(d_min, d_max, 4, size=n_events)

    # Set up the slicer to evaluate the catalog we just made
    kne_slicer = slicers.UserPointsSlicer(ra, dec, lat_lon_deg=True, badval=0)

    # Attach the per-event information to the slicer
    kne_slicer.slice_points["peak_time"] = peak_times
    kne_slicer.slice_points["file_indx"] = file_indx
    kne_slicer.slice_points["distance"] = distance

    return kne_slicer
class PrestoColorKNePopMetric(metrics.BaseMetric):
    """Evaluate simulated kilonova light curves against the Presto-Color
    detection criterion and score them with the Presto-Color phase space.

    For each slice point, `run` returns a dictionary with the keys
    ``presto_color_detect``, ``scoreS`` and ``scoreP`` (plus ``lc`` and
    ``slice_point`` when ``output_lc`` is True); the ``reduce_*`` methods
    extract the individual values.
    """

    def __init__(
        self,
        metric_name="KNePopMetric",
        mjd_col="observationStartMJD",
        m5_col="fiveSigmaDepth",
        filter_col="filter",
        night_col="night",
        pts_needed=2,
        file_list=None,
        mjd0=SURVEY_START_MJD,
        output_lc=False,
        skyregion="galactic",
        thr=0.003,
        **kwargs,
    ):
        """
        Parameters
        ----------
        metric_name : `str`, optional
            Name of the metric, passed to the base class.
        mjd_col : `str`, optional
            Name of the column holding the observation MJD.
        m5_col : `str`, optional
            Name of the column holding the five sigma depth.
        filter_col : `str`, optional
            Name of the column holding the filter name.
        night_col : `str`, optional
            Name of the column holding the night.
        pts_needed : `int`, optional
            Stored on the instance but not used by this metric.
        file_list : `str` or None, optional
            File containing input lightcurves (passed to `KnLc`).
        mjd0 : `float`, optional
            MJD of the start of the survey.
        output_lc : `bool`, optional
            Flag to whether or not to output lightcurve for each object.
        skyregion : `str`, optional
            The skyregion of interest.
            Only two options: 'galactic' and 'extragalactic'
        thr : `float`, optional
            Threshold for "classification" of events via the Score_S
        """
        # Names of the columns read from each data slice
        self.mjd_col = mjd_col
        self.m5_col = m5_col
        self.filter_col = filter_col
        self.night_col = night_col

        # If True, the light curve (and slice point) is added to the output
        self.output_lc = output_lc
        self.thr = thr
        self.skyregion = skyregion
        self.mjd0 = mjd0

        # Read in the light curve templates
        self.lightcurves = KnLc(file_list=file_list)

        # Per-filter coefficients, multiplied by E(B-V) in `run`
        self.ax1 = DustValues().ax1

        super().__init__(
            col=[self.mjd_col, self.m5_col, self.filter_col, self.night_col],
            units="Detected, 0 or 1",
            metric_name=metric_name,
            maps=["DustMap"],
            **kwargs,
        )

        # Unused ..
        self.pts_needed = pts_needed

    def _presto_color_detect(self, around_peak, filters):
        """Detection criteria of presto cadence:
        at least three detections, in at least two filters, with at least
        one filter having two or more detections.

        Parameters
        ----------
        around_peak : `np.ndarray`, (N,)
            indexes corresponding to 5sigma detections
        filters : `np.ndarray`, (N,)
            filters in which detections happened

        Returns
        -------
        result : `int`
            1 if the criteria are met, 0 otherwise.
        """
        if np.size(around_peak) < 3:
            return 0

        _, visits_per_filter = np.unique(filters, return_counts=True)

        # Fewer than two distinct filters
        if np.size(visits_per_filter) < 2:
            return 0
        # No filter has two or more detections
        if np.max(visits_per_filter) < 2:
            return 0
        return 1

    def _enquiry(self, hash_table, info_dict, band1, band2, d_t1, d_t2, d_mag, color):
        """Return the value in the probability cube provided the
        coordinates in the Presto-color phase space of an observation
        triplet.

        Parameters
        ----------
        hash_table : `np.ndarray`
            Values of the Presto-color phase space, indexed here as
            [band pair, time pair, magnitude bin, color bin].
        info_dict : `dict`
            Contains the essential information of the hash_table above;
            the keys read here are 'BinMag', 'BinColor', 'BandPairs',
            'dT1s', 'dT2s' and 'TimePairs'.
        band1, band2 : `str`, `str`
            The two filters that comprise the Presto-color observation
            triplet. The filters are the 6 bands of LSST: u, g, r, i, z, y.
            Band1 and band2 should be different.
        d_t1, d_t2 : `float`, `float`
            The time gaps of the Presto-color observation triplet.
        d_mag : `float`
            The magnitude change between the observations of the same band
        color : `float`
            The difference in magnitude of observations in different bands.

        Returns
        -------
        rate
            The hash_table entry for these coordinates, or 0 when d_mag or
            color falls outside the binned range.

        hash_table and info_dict have to be loaded from premade
        Presto-color data file.
        """
        mag_bins = info_dict["BinMag"]
        if not mag_bins[0] <= d_mag < mag_bins[-1]:
            return 0
        color_bins = info_dict["BinColor"]
        if not color_bins[0] <= color < color_bins[-1]:
            return 0

        band_idx = info_dict["BandPairs"].index(band1 + band2)

        # Snap both time gaps to the nearest grid value, then locate that
        # pair among the tabulated time pairs.
        nearest_dt1 = info_dict["dT1s"][abs(d_t1 - info_dict["dT1s"]).argmin()]
        nearest_dt2 = info_dict["dT2s"][abs(d_t2 - info_dict["dT2s"]).argmin()]
        nearest_pair = [nearest_dt1, nearest_dt2]
        time_idx = np.where((info_dict["TimePairs"] == nearest_pair).all(axis=1))[0][0]

        # Index of the last bin edge that is <= the value
        mag_idx = np.where(d_mag >= mag_bins)[0][-1]
        color_idx = np.where(color >= color_bins)[0][-1]

        return hash_table[band_idx, time_idx, mag_idx, color_idx]

    def _get_score(self, result, hash_table, info_dict, thr):
        """Get the score of a strategy from the Presto-color perspective.

        Parameters
        ----------
        result : `pd.DataFrame`
            Dataframe that contains the results of the observations.
            The columns include
            t: the time of the observation
            mag: the detected magnitude
            maglim: the limit fiveSigmaDepth that can be detected
            filter: the filter used for the observation
        hash_table : `np.ndarray`
            Contains the values of the Presto-color phase space
        info_dict : `dict`
            Contains the essential information of the hash_table above.
        thr : `float`
            The threshold used for the 'S' score.

        Returns
        -------
        score_s : `int`
            1 if any observation triplet has a rate below ``thr``, else 0.
        score_p : `float` or `int`
            The largest value of (1 - rate) over all triplets, or 0.

        hash_table and info_dict have to be loaded from the premade
        Presto-color data file.
        """
        short_gap = 8.125 / 24  # 8 h 7.5 min
        long_gap = 32.25 / 24  # 32 h 15 min

        # Keep only valid detections and renumber the rows from 0
        detected = result[result.mag < result.maglim].reset_index(drop=True)

        # gaps[i, j] = t[j] - t[i] for every pair of detections
        times = detected.t.values
        gaps = times[np.newaxis, :] - times[:, np.newaxis]

        # Pairs separated by less than 32 hours (2 nights)
        rows, cols = np.where(abs(gaps) < long_gap)

        triplet_coords = []
        for first in range(gaps.shape[0]):
            # Later detections close enough in time to detection `first`
            later = cols[(rows == first) & (cols > first)]
            for pair in combinations(later, 2):
                indices = np.array([first, *pair])
                bands = detected["filter"][indices].values
                if len(np.unique(bands)) != 2:
                    continue

                # The band that appears once will be band2
                counts = np.array([np.count_nonzero(band == bands) for band in bands])
                in_band1 = indices[counts == 2]
                # The index of the observation in band2
                idx2 = indices[counts == 1][0]
                # The indexes of the two observations in band1
                idx11 = in_band1[0]
                idx12 = in_band1[1]

                gap_11 = abs(gaps[idx11, idx2])
                gap_12 = abs(gaps[idx12, idx2])
                if gap_12 < gap_11 and gap_12 < short_gap:
                    # Use the band1 observation closest to the band2 one
                    idx11, idx12 = idx12, idx11
                elif gap_11 > short_gap:
                    continue

                d_t1 = gaps[idx11, idx2]
                d_t2 = gaps[idx11, idx12]

                band1 = bands[counts == 2][0]
                band2 = bands[counts == 1][0]
                if band1 + band2 in ("uy", "yu"):
                    continue

                d_mag = (detected.mag[idx11] - detected.mag[idx12]) * np.sign(d_t2)
                color = detected.mag[idx11] - detected.mag[idx2]

                triplet_coords.append([band1, band2, d_t1, d_t2, d_mag, color])

        rates = [self._enquiry(hash_table, info_dict, *coords) for coords in triplet_coords]

        score_s = 1 if any(rate < thr for rate in rates) else 0
        score_p = max([0] + [(1 - rate) for rate in rates])
        return score_s, score_p

    def run(self, data_slice, slice_point=None):
        """Compute the Presto-Color detection flag and scores for one event.

        Parameters
        ----------
        data_slice : `np.ndarray`
            The observations at this slice point; sorted in place by
            the MJD column.
        slice_point : `dict`
            Information about the event; the keys read here are
            'peak_time', 'file_indx', 'ebv' and 'distance'.

        Returns
        -------
        result : `dict`
            Contains 'presto_color_detect', 'scoreS' and 'scoreP', plus
            'lc' and 'slice_point' when ``output_lc`` is True.
        """
        data_slice.sort(order=self.mjd_col)

        mjd = data_slice[self.mjd_col]
        m5 = data_slice[self.m5_col]
        filter_names = data_slice[self.filter_col]

        # Time relative to the peak time of this event
        t = mjd - self.mjd0 - slice_point["peak_time"]

        mags = np.zeros(t.size, dtype=float)
        for filtername in np.unique(filter_names):
            infilt = np.where(filter_names == filtername)
            mags[infilt] = self.lightcurves.interp(t[infilt], filtername, lc_indx=slice_point["file_indx"])
            # Apply dust extinction on the light curve
            a_x = self.ax1[filtername] * slice_point["ebv"]
            mags[infilt] += a_x
            # Apply the distance modulus (distance is multiplied by 1e6)
            distmod = 5 * np.log10(slice_point["distance"] * 1e6) - 5.0
            mags[infilt] += distmod

        # Points brighter than the limiting depth within 30 days of t=0
        around_peak = np.where((t > 0) & (t < 30) & (mags < m5))[0]
        # Filters in which the detections happened
        detected_filters = filter_names[around_peak]

        result = {"presto_color_detect": self._presto_color_detect(around_peak, detected_filters)}

        # Light curve restricted to points with magnitude below 100
        # NOTE(review): presumably this drops epochs outside the template's
        # time range -- confirm against KnLc.interp.
        idx = np.where(mags < 100)[0]
        lc = {
            "t": mjd[idx],
            "mag": mags[idx],
            "maglim": m5[idx],
            "filter": filter_names[idx],
        }

        if self.output_lc is True:
            result["lc"] = lc
            result["slice_point"] = slice_point

        if result["presto_color_detect"] != 1:
            result["scoreS"] = 0
            result["scoreP"] = 0
            return result

        info_dict, hash_table = _load_hash(skyregion=self.skyregion)
        result["scoreS"], result["scoreP"] = self._get_score(
            pd.DataFrame(lc),
            hash_table=hash_table,
            info_dict=info_dict,
            thr=self.thr,
        )
        return result

    def reduce_presto_color_detect(self, metric):
        """Return the 'presto_color_detect' entry of a `run` result."""
        return metric["presto_color_detect"]

    def reduce_score_s(self, metric):
        """Return the 'scoreS' entry of a `run` result."""
        return metric["scoreS"]

    def reduce_score_p(self, metric):
        """Return the 'scoreP' entry of a `run` result."""
        return metric["scoreP"]