Source code for rubin_sim.maf.plots.skyproj_plotters

import abc
import warnings
from collections import defaultdict
from numbers import Integral
from weakref import WeakKeyDictionary

import astropy.units as u
import healpy as hp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
import pandas as pd
import skyproj
from astropy.coordinates import SkyCoord
from rubin_scheduler.utils import _healbin

from .plot_handler import BasePlotter, apply_zp_norm
from .spatial_plotters import set_color_lims, set_color_map


[docs] def compute_circle_points( center_ra, center_decl, radius=90.0, start_bear=0, end_bear=360, step=1, ): """Create points along a circle or arc on a sphere Parameters ---------- center_ra : `float` R.A. of the center of the circle (deg.). center_decl : `float` Decl. of the center of the circle (deg.). radius : float, optional Radius of the circle (deg.), by default 90.0 start_bear : int, optional Bearing (E. of N.) of the start of the circle (deg.), by default 0 end_bear : int, optional Bearing (E. of N.) of the end of the circle (deg.), by default 360 step : int, optional Spacing of the points along the circle (deg.), by default 1 Returns ------- circle : `pandas.DataFrame` DataFrame with points in the circle. """ ras = [] decls = [] bearing_angles = np.arange(start_bear, end_bear + step, step) * u.deg center_coords = SkyCoord(center_ra * u.deg, center_decl * u.deg) circle_coords = center_coords.directional_offset_by(bearing_angles, radius * u.deg) bearings = bearing_angles.value.tolist() ras = circle_coords.ra.deg.tolist() decls = circle_coords.dec.deg.tolist() # de-normalize RA so there are no large jumps, which can # confuse cartopy + matplotlib previous_ra = ras[0] for ra_index in range(1, len(ras)): if ras[ra_index] - previous_ra > 180: ras[ra_index] -= 360 if ras[ra_index] - previous_ra < -180: ras[ra_index] += 360 previous_ra = ras[ra_index] circle = pd.DataFrame( { "bearing": bearings, "ra": ras, "decl": decls, } ).set_index("bearing") return circle
[docs] class SkyprojPlotter(BasePlotter, abc.ABC): """Base class for all plotters that use the skyproj module.""" # Store a mapping between figures and existing SkyprojPlotters, but use # weak keys so that the figure can be cleaned up by the garbage collector, # and when it is entry in the dictionary automatically deleted. skyproj_instances = WeakKeyDictionary() default_decorations = ["ecliptic", "galactic_plane"] default_colorbar_kwargs = { "location": "bottom", "shrink": 0.75, "aspect": 25, "pad": 0.1, "orientation": "horizontal", } ecliptic_kwargs = {"edgecolor": "green"} galactic_plane_kwargs = {"edgecolor": "blue"} sun_kwargs = {"color": "yellow"} moon_kwargs = {"color": "orange"} horizon_kwargs = {"edgecolor": "black", "linewidth": 3} def __init__(self): super().__init__() # Customize our plotters members for our new plot self.plot_type = "Skyproj" self.default_plot_dict = {} self.default_plot_dict.update( { "subplot": 111, "skyproj": skyproj.MollweideSkyproj, "skyproj_kwargs": {"lon_0": 0}, "decorations": self.default_decorations, "title": None, "xlabel": None, "label": None, "labelsize": None, "fontsize": None, "figsize": None, } ) self._initialize_plot_dict({}) self.figure = None def _initialize_plot_dict(self, user_plot_dict): # Use a defaultdict so that everything not explicitly set return None. def return_none(): return None self.plot_dict = defaultdict(return_none) self.plot_dict.update(self.default_plot_dict) self.plot_dict.update(user_plot_dict) def _prepare_skyproj(self, fig=None): # Exctract elemens of plot_dict that need to be passed to plt.figure fig_kwargs = {k: v for k, v in self.plot_dict.items() if k in ["figsize"]} self.figure = plt.figure(**fig_kwargs) if fig is None else fig # Normalize subplot specification, # so that it can be set in either (single) integer form or tuple form. # This is needed so that the same subplot set by different ways will be # stored in the same key in the instance dictionary. subplot_in = self.plot_dict["subplot"] in_integral_form = isinstance(subplot_in, Integral) and len(str(subplot_in)) == 3 subplot = tuple(map(int, str(subplot_in))) if in_integral_form else tuple(subplot_in) # If there is an existing instance of SkyprojPlotter (or one of its # subclasses), reuse it. # Otherwise, create a new one and store it so that it may be reused. if self.figure not in self.skyproj_instances: self.skyproj_instances[self.figure] = {} if subplot in self.skyproj_instances[self.figure]: existing_instance = self.skyproj_instances[self.figure][subplot] assert isinstance(existing_instance, self.plot_dict["skyproj"]) self.skyproj = existing_instance else: # This instance of Axis will get deleted and replaced when the # corresponding SkyPlot class is created, but this instances is # needed so SkyPlot will put its instance in the correct relation # with its figure. ax = self.figure.add_subplot(*subplot) self.skyproj = self.plot_dict["skyproj"](ax=ax, **self.plot_dict["skyproj_kwargs"]) self.skyproj_instances[self.figure][subplot] = self.skyproj
[docs] def draw_circle(self, center_ra, center_decl, radius=90, **kwargs): """Draw a circle on the sphere. Parameters ---------- center_ra : `float` R.A. of the center of the circle (deg.). center_decl : `float` Decl. of the center of the circle (deg.). radius : float, optional Radius of the circle (deg.), by default 90.0 **kwargs Additional keyword arguments passed to `skyproj._Skyproj.draw_polygon`. """ points = compute_circle_points(center_ra, center_decl, radius) self.skyproj.draw_polygon(points.ra, points.decl, **kwargs)
[docs] def draw_ecliptic(self): """Draw the ecliptic the sphere. Parameters ---------- **kwargs Keyword arguments passed to `skyproj._Skyproj.draw_polygon`. """ ecliptic_pole = SkyCoord(lon=0 * u.degree, lat=90 * u.degree, frame="geocentricmeanecliptic").icrs self.draw_circle(ecliptic_pole.ra.deg, ecliptic_pole.dec.deg, **self.ecliptic_kwargs)
[docs] def draw_galactic_plane(self): """Draw the galactic plane the sphere. Parameters ---------- **kwargs Keyword arguments passed to `skyproj._Skyproj.draw_polygon`. """ galactic_pole = SkyCoord(l=0 * u.degree, b=90 * u.degree, frame="galactic").icrs self.draw_circle(galactic_pole.ra.deg, galactic_pole.dec.deg, **self.galactic_plane_kwargs)
[docs] def draw_zd(self, zd=90, **kwargs): """Draw a circle at a given zenith distance. Parameters ---------- zd : `float` The zenith distance to draw, in degrees. Defaults to 90. **kwargs Keyword arguments passed to `skyproj._Skyproj.draw_polygon`. Notes ----- This method relies on the site location and time (as an mjd) set in the ``model_observatory`` element of the ``plot_dict``, which should be of the class `rubin_scheduler.scheduler.model_observatory.ModelObservatory`. """ if "model_observatory" not in self.plot_dict or self.plot_dict["model_observatory"] is None: warnings.warn("plot_dict['model_observatory'] must be set to plot a zenith distance circle") return model_observatory = self.plot_dict["model_observatory"] lmst = model_observatory.return_conditions().lmst * 360 / 24 latitude = model_observatory.location.lat.deg self.draw_circle(lmst, latitude, zd, **kwargs)
[docs] def draw_body(self, body="sun", **kwargs): """Mark the sun or moon. Parameters ---------- body : `str` The name of the body to draw, either `sun` or `moon`. Defaults to `sun` **kwargs Keyword arguments passed to `skyproj._Skyproj.scatter` Notes ----- This method relies on the site location and time (as an mjd) set in the ``model_observatory`` element of the ``plot_dict``, which should be of the class `rubin_scheduler.scheduler.model_observatory.ModelObservatory`. """ if "model_observatory" not in self.plot_dict or self.plot_dict["model_observatory"] is None: warnings.warn(f"plot_dict['model_observatory'] must be set to plot the {body}") return mjd = self.plot_dict["model_observatory"].mjd model_observatory = self.plot_dict["model_observatory"] sun_moon_positions = model_observatory.almanac.get_sun_moon_positions(mjd) ra = np.degrees(sun_moon_positions[f"{body}_RA"].item()) decl = np.degrees(sun_moon_positions[f"{body}_dec"].item()) self.skyproj.scatter(ra, decl, **kwargs)
[docs] def decorate(self): """Add decorations/annotations to the sky plot. Notes ----- For decorations that depend on the time or location of the observer (e.g., the horizon and sun and moon positions) this method relies on the site location and time (as an mjd) set in the ``model_observatory`` element of the ``plot_dict``, which should be of the class `rubin_scheduler.scheduler.model_observatory.ModelObservatory`. Other decorations (e.g., the ecliptic and galactic plane) can be shown even when ``model_observatory`` is not set. """ decorations = self.plot_dict["decorations"] if "ecliptic" in decorations: self.draw_ecliptic() if "galactic_plane" in decorations: self.draw_galactic_plane() if "sun" in decorations: self.draw_body("sun", **self.sun_kwargs) if "moon" in decorations: self.draw_body("moon", **self.moon_kwargs) if "horizon" in decorations: self.draw_zd(**self.horizon_kwargs)
@abc.abstractmethod def draw(self, metric_values, slicer): raise NotImplementedError()
[docs] def __call__(self, metric_values, slicer, user_plot_dict, fig=None): """Make a plot. Parameters ---------- metric_value : `numpy.ma.MaskedArray` The metric values from the bundle. slicer : `rubin_sim.maf.slicers.TwoDSlicer` The slicer. user_plot_dict: `dict` Dictionary of plot parameters set by user (overrides default values). fig : `matplotlib.figure.Figure` Matplotlib figure. The default is ``None``, which starts new figure. Returns ------- fig : `matplotlib.figure.Figure` Figure with the plot. """ self._initialize_plot_dict(user_plot_dict) self._prepare_skyproj(fig) self.draw(metric_values, slicer) self.decorate() # Do not show axis labels unless they are set explicitly # or specified as decorators. if self.plot_dict["xlabel"] is None: if "xlabel" not in self.plot_dict["decorations"]: self.skyproj.set_xlabel("", visible=False) else: self.skyproj.set_xlabel(self.plot_dict["xlabel"]) if self.plot_dict["ylabel"] is None: if "ylabel" not in self.plot_dict["decorations"]: self.skyproj.set_ylabel("", visible=False) else: self.skyproj.set_ylabel(self.plot_dict["ylabel"]) if self.plot_dict["label"] is not None: label_kwargs = {} if "fontsize" in self.plot_dict: label_kwargs["fontsize"] = self.plot_dict["fontsize"] if "label_loc" in self.plot_dict: label_kwargs["loc"] = self.plot_dict["label_loc"] self.skyproj.ax.set_title(self.plot_dict["label"], **label_kwargs) if self.plot_dict["title"] is not None: title_kwargs = {} if "fontsize" in self.plot_dict: title_kwargs["fontsize"] = self.plot_dict["fontsize"] self.figure.suptitle(self.plot_dict["title"], **title_kwargs) return self.figure
[docs] class HpxmapPlotter(SkyprojPlotter): default_hpixmap_kwargs = {"zoom": False} def __init__(self): super().__init__() self.default_plot_dict["colorbar"] = True self.object_plotter = False
[docs] def draw_colorbar(self): """Add a color bar.""" plot_dict = self.plot_dict colorbar_kwargs = {} colorbar_kwargs.update(self.default_colorbar_kwargs) if "extend" in self.plot_dict: colorbar_kwargs["extendrect"] = self.plot_dict["extend"] == "neither" if "cbar_format" in self.plot_dict: colorbar_kwargs["format"] = self.plot_dict["cbar_format"] if "cbar_orientation" in self.plot_dict: colorbar_kwargs["orientation"] = self.plot_dict["cbar_orientation"] match self.plot_dict["cbar_orientation"].lower(): case "vertical": colorbar_kwargs["shrink"] = 0.5 colorbar_kwargs["location"] = "right" case "horizontal": colorbar_kwargs["location"] = "bottom" case _: orientation = self.plot_dict["cbar_orientation"] raise NotImplementedError(f"cbar_orintation {orientation} is not supported") colorbar = self.skyproj.draw_colorbar(**colorbar_kwargs) colorbar.set_label(self.plot_dict["xlabel"], fontsize=self.plot_dict["fontsize"]) if "labelsize" in self.plot_dict and self.plot_dict["labelsize"] is not None: colorbar.ax.tick_params(labelsize=plot_dict["labelsize"]) if "n_ticks" in self.plot_dict and self.plot_dict["n_ticks"] is not None: tick_locator = mpl.ticker.MaxNLocator(nbins=plot_dict["n_ticks"]) colorbar.locator = tick_locator colorbar.update_ticks() # If outputing to PDF, this fixes the colorbar white stripes if "cbar_edge" in self.plot_dict and self.plot_dict["cbar_edge"]: colorbar.solids.set_edgecolor("face")
[docs] def decorate(self): super().decorate() if self.plot_dict["colorbar"]: self.draw_colorbar()
def _transform_to_healpix(self, metric_value_in, slicer): # Bin the values up on a healpix grid. metric_value = _healbin( slicer.slice_points["ra"], slicer.slice_points["dec"], metric_value_in.filled(slicer.badval), nside=self.plot_dict["nside"], reduce_func=self.plot_dict["reduce_func"], fill_val=slicer.badval, ) mask = np.zeros(metric_value.size) mask[np.where(metric_value == slicer.badval)] = 1 metric_value = ma.array(metric_value, mask=mask) return metric_value def _compatible_color_limits(self, norm, clims): # Log can't be 0 compatible = (norm != "log") or (clims[0] < 0 and clims[1] < 0) or (clims[0] > 0 and clims[1] > 0) return compatible
[docs] def draw(self, metric_values_in, slicer): """Draw the healpix map. Parameters ---------- metric_value : `numpy.ma.MaskedArray` The metric values from the bundle. slicer : `rubin_sim.maf.slicers.TwoDSlicer` The slicer """ kwargs = {} kwargs.update(self.default_hpixmap_kwargs) # Check if we have a valid HEALpix slicer if "Heal" not in slicer.slicer_name: hpix_map = self._transform_to_healpix(metric_values_in, slicer) else: hpix_map = metric_values_in.copy() hpix_map = apply_zp_norm(hpix_map, self.plot_dict) if "draw_hpxmap_kwargs" in self.plot_dict: kwargs.update(self.plot_dict["draw_hpxmap_kwargs"]) try: kwargs["cmap"] = set_color_map(self.plot_dict) except AttributeError: # Probably here because an invalid cmap was set pass if self.plot_dict["log_scale"]: if "norm" in kwargs and kwargs["norm"] != "log": raise ValueError("Contradictory values for color scale normalization set.") kwargs["norm"] = "log" clims = set_color_lims(hpix_map, self.plot_dict) if "norm" not in kwargs or self._compatible_color_limits(kwargs["norm"], clims): kwargs["vmin"] = clims[0] kwargs["vmax"] = clims[1] else: warnings.warn( "Using log normalization, but color limits pass through 0. " "Adjusting so plotting doesn't fail" ) # Filling in masked values must take place after automatic color limits # might be calculated! # This masking replicates the behavior of HealpixSkyMap, but is it # what we really want to do? Why not use the default handling of # masked values applied by skyproj? if "mask_below" in self.plot_dict and self.plot_dict["mask_below"] is not None: to_mask = np.where(hpix_map <= self.plot_dict["mask_below"])[0] hpix_map.mask[to_mask] = True hpix_map = hpix_map.filled(hp.UNSEEN) else: hpix_map = hpix_map.filled(slicer.badval) self.skyproj.draw_hpxmap(hpix_map, **kwargs)
[docs] class VisitPerimeterPlotter(SkyprojPlotter): default_visits_kwargs = { "edgecolor": "black", "linewidth": 0.2, } def __init__(self): super().__init__() self.plot_type = "SkyprojVisitPerimeter" self.object_plotter = True
[docs] def draw(self, metric_values, slicer): """Draw the perimeters of visits. Parameters ---------- metric_value : `numpy.ma.MaskedArray` The metric values from the bundle. slicer : `rubin_sim.maf.slicers.TwoDSlicer` The slicer """ # If there are no visits, just return if metric_values is None or len(metric_values) == 0: return kwargs = {} kwargs.update(self.default_visits_kwargs) visits = metric_values[0] if "draw_polygon_kwargs" in self.plot_dict: kwargs.update(self.plot_dict["draw_polygon_kwargs"]) camera_perimeter_func = self.plot_dict["camera_perimeter_func"] ras, decls = camera_perimeter_func(visits["fieldRA"], visits["fieldDec"], visits["rotSkyPos"]) for visit_ras, visit_decls in zip(ras, decls): self.skyproj.draw_polygon(visit_ras, visit_decls, **kwargs)