__all__ = ("apply_zp_norm", "PlotHandler", "BasePlotter")
import os
import warnings
import matplotlib.pyplot as plt
import numpy as np
import rubin_sim.maf.utils as utils
def apply_zp_norm(metric_value, plot_dict):
    """Apply the optional zeropoint and normalization from a plot dictionary.

    Parameters
    ----------
    metric_value : number or array-like
        The value(s) to adjust. The input object is not modified in place;
        a new value is returned whenever an adjustment is applied.
    plot_dict : `dict`
        Plot parameters. If ``"zp"`` is present and not None it is
        subtracted from the metric value; if ``"norm_val"`` is present and
        not None the (zeropoint-subtracted) value is then divided by it.

    Returns
    -------
    metric_value : number or array-like
        The adjusted value, or the input unchanged if neither key is set.
    """
    # The zeropoint is removed first, then the normalization is applied.
    zero_point = plot_dict.get("zp")
    if zero_point is not None:
        metric_value = metric_value - zero_point

    scale = plot_dict.get("norm_val")
    if scale is not None:
        metric_value = metric_value / scale

    return metric_value
class BasePlotter:
    """
    Serve as the base type for MAF plotters and example of API.

    Attributes set in ``__init__``:

    * ``plot_type`` -- name of the plot type (None in the base class).
    * ``object_plotter`` -- whether the plotter handles 'object' dtype
      metric values (False in the base class).
    * ``default_plot_dict`` -- default plot parameters for the plotter.
    """

    def __init__(self):
        self.plot_type = None
        # PlotHandler.plot reads this flag to decide whether 'object' metric
        # values can be passed to the plotter; define it here so subclasses
        # that rely on the base initializer do not raise AttributeError.
        self.object_plotter = False
        # This should be included in every subsequent default_plot_dict
        # (assumed to be present).
        self.default_plot_dict = {
            "title": None,
            "xlabel": None,
            "label": None,
            "labelsize": None,
            "fontsize": None,
            "figsize": None,
        }

    def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
        """
        Parameters
        ----------
        metric_value : `numpy.ma.MaskedArray`
            The metric values from the bundle.
        slicer : `rubin_sim.maf.slicers.TwoDSlicer`
            The slicer.
        user_plot_dict: `dict`
            Dictionary of plot parameters set by user
            (overrides default values).
        fig : `matplotlib.figure.Figure`
            Matplotlib figure to use. Default = None, starts new figure.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            Figure with the plot. The base implementation is a placeholder:
            it draws nothing and returns None.
        """
        pass
class PlotHandler:
    """Create plots from a single or series of metric bundles.

    Parameters
    ----------
    out_dir : `str`, optional
        Directory to save output plots.
    results_db : `rubin_sim.maf.ResultsDb`, optional
        ResultsDb into which to record plot location and information.
    savefig : `bool`, optional
        Flag for saving images to disk (versus create and return
        to caller).
    fig_format : `str`, optional
        Figure format to use to save full-size output. Default PDF.
    dpi : `int`, optional
        DPI to save output figures to disk at. (for matplotlib figures).
    thumbnail : `bool`, optional
        Flag for saving thumbnails (reduced size pngs) to disk.
    trim_whitespace : `bool`, optional
        Flag for trimming whitespace option to matplotlib output figures.
        Default True, usually doesn't need to be changed.
    """

    def __init__(
        self,
        out_dir=".",
        results_db=None,
        savefig=True,
        fig_format="pdf",
        dpi=600,
        thumbnail=True,
        trim_whitespace=True,
    ):
        self.out_dir = out_dir
        self.results_db = results_db
        self.savefig = savefig
        self.fig_format = fig_format
        self.dpi = dpi
        self.trim_whitespace = trim_whitespace
        self.thumbnail = thumbnail
        # Default line color per filter name (" " means 'no filter').
        self.filtercolors = {
            "u": "cyan",
            "g": "g",
            "r": "y",
            "i": "r",
            "z": "m",
            "y": "k",
            " ": None,
        }
        # Sort rank per filter name, used to order the metric bundles.
        self.filterorder = {" ": -1, "u": 0, "g": 1, "r": 2, "i": 3, "z": 4, "y": 5}

    # ------------------------------------------------------------------
    # Small private helpers.
    # ------------------------------------------------------------------
    @staticmethod
    def _unique_in_order(items):
        """Return the distinct (hashable) items as a list, first-seen order."""
        return list(dict.fromkeys(items))

    @staticmethod
    def _n_distinct(values, ignore_none=False):
        """Count distinct values, treating each sequence value as one item.

        ``np.unique`` is not used here because it flattens sequence values
        (so two identical ``[0, 1]`` limits would count as two values) and
        fails to sort mixed None/number input.
        """
        seen = set()
        for value in values:
            if value is None:
                if ignore_none:
                    continue
            elif np.ndim(value) > 0:
                value = tuple(np.asarray(value, dtype=object).ravel().tolist())
            try:
                seen.add(value)
            except TypeError:
                # Unhashable value: fall back to its text representation.
                seen.add(repr(value))
        return len(seen)

    def _filter_rank(self, m_b):
        """Return the sort rank of a bundle from filter tokens in file_root.

        The last '_'-separated token of ``file_root`` that is a key of
        ``self.filterorder`` decides the rank; bundles with no such token
        rank after every filter.
        """
        ranks = [self.filterorder[tok] for tok in m_b.file_root.split("_") if tok in self.filterorder]
        return ranks[-1] if ranks else float("inf")

    # ------------------------------------------------------------------
    # Public configuration.
    # ------------------------------------------------------------------
    def set_metric_bundles(self, m_bundles):
        """
        Set the metric bundle or bundles (list or dictionary).
        Reuse the PlotHandler by resetting this reference.
        The metric bundles have to have the same slicer.

        The bundles are stored in ``self.m_bundles`` sorted by filter
        (see ``self.filterorder``); bundles sharing a filter, and bundles
        without a recognised filter (placed last), keep their input order.

        Parameters
        ----------
        m_bundles : `list` or `dict`
            The metric bundles; for a dictionary the values are used.

        Raises
        ------
        ValueError
            If the bundles do not all use the same type of slicer.
        """
        bundles = list(m_bundles.values()) if isinstance(m_bundles, dict) else list(m_bundles)
        # sorted() is stable, so equal-rank bundles keep their input order.
        # (Inserting each bundle at index == filter rank does not give a
        # sorted result once the list is shorter than the rank.)
        self.m_bundles = sorted(bundles, key=self._filter_rank)
        self.slicer = self.m_bundles[0].slicer
        for m_b in self.m_bundles:
            if m_b.slicer.slicer_name != self.slicer.slicer_name:
                raise ValueError("MetricBundle items must have the same type of slicer")
        self._combine_metric_names()
        self._combine_run_names()
        self._combine_metadata()
        self._combine_constraints()
        self.set_plot_dicts(reset=True)

    def set_plot_dicts(self, plot_dicts=None, plot_func=None, reset=False):
        """
        Set or update the plot_dict for the (possibly joint) plots.

        Resolution is: (from lowest to higher)
        auto-generated items (colors/labels/titles)
        < anything previously set in the plot_handler
        < auto-generated x/y labels and non-None defaults set by the plotter
        (only when ``plot_func`` is given)
        < explicitly set items in the metricBundle plot_dict
        < explicitly set items in the plot_dicts list passed to this method.

        Parameters
        ----------
        plot_dicts : `dict` or `list` of [`dict`], optional
            User overrides; a single dictionary is applied to every bundle.
        plot_func : plotter object, optional
            Plotter providing ``plot_type`` and ``default_plot_dict``.
        reset : `bool`, optional
            If True, discard anything previously stored in the handler.
        """
        if reset:
            # Have to explicitly set each dictionary to a (separate)
            # blank dictionary.
            self.plot_dicts = [{} for _ in self.m_bundles]
        if isinstance(plot_dicts, dict):
            # We were passed a single dictionary, not a list.
            plot_dicts = [plot_dicts] * len(self.m_bundles)
        auto_label_list = self._build_legend_labels()
        auto_color_list = self._build_colors()
        auto_cbar = self._build_cbar_format()
        auto_title = self._build_title()
        if plot_func is not None:
            auto_xlabel, auto_ylabel = self._build_x_ylabels(plot_func)
        # Loop through each bundle and generate a plot_dict for it.
        for i, bundle in enumerate(self.m_bundles):
            # First use the auto-generated values.
            tmp_plot_dict = {
                "title": auto_title,
                "label": auto_label_list[i],
                "color": auto_color_list[i],
                "cbar_format": auto_cbar,
            }
            # Update that with anything previously set in the plot_handler.
            tmp_plot_dict.update(self.plot_dicts[i])
            # Then override with plot_dict items set explicitly
            # based on the plot type.
            if plot_func is not None:
                tmp_plot_dict["xlabel"] = auto_xlabel
                tmp_plot_dict["ylabel"] = auto_ylabel
                # Replace auto-generated plot dict items with things
                # set by the plotter defaults, if they are not None.
                for k, v in plot_func.default_plot_dict.items():
                    if v is not None:
                        tmp_plot_dict[k] = v
            # Then add/override based on the bundle plot_dict parameters
            # if they are set.
            tmp_plot_dict.update(bundle.plot_dict)
            # Finally, override with anything set explicitly by the user.
            if plot_dicts is not None:
                tmp_plot_dict.update(plot_dicts[i])
            # And save this new dictionary back in the class.
            self.plot_dicts[i] = tmp_plot_dict
        # Check that the plot_dicts do not conflict.
        self._check_plot_dicts()

    # ------------------------------------------------------------------
    # Combining information from the bundles.
    # ------------------------------------------------------------------
    def _combine_metric_names(self):
        """
        Combine metric names.

        Sets ``self.metric_names`` (set of unique names) and
        ``self.joint_metric_names`` (a single combined string, built in
        bundle order so that it is reproducible between runs).
        """
        names = self._unique_in_order(m_b.metric.name for m_b in self.m_bundles)
        self.metric_names = set(names)
        # Find a pleasing combination of the metric names.
        order = ["u", "g", "r", "i", "z", "y"]
        if len(names) == 1:
            joint_name = names[0]
        else:
            # Split each unique name into words to see
            # if we can merge the names.
            name_lists = [x.split() for x in names]
            if len({len(x) for x in name_lists}) == 1:
                # Same number of words everywhere: merge word by word.
                joint_words = []
                for words in zip(*name_lists):
                    distinct = set(words)
                    if distinct.issubset(order):
                        # Special case of filters: put them in order.
                        joint_words.append("".join(f for f in order if f in distinct))
                    else:
                        # Otherwise, just join the distinct words.
                        joint_words.append("".join(self._unique_in_order(words)))
                joint_name = " ".join(joint_words)
            else:
                # Different lengths: just join everything.
                joint_name = " ".join(names)
        self.joint_metric_names = joint_name

    def _combine_run_names(self):
        """
        Combine runNames.

        Sets ``self.run_names`` (set) and ``self.joint_run_names``
        (space-joined, in bundle order).
        """
        names = self._unique_in_order(m_b.run_name for m_b in self.m_bundles)
        self.run_names = set(names)
        self.joint_run_names = " ".join(names)

    def _combine_metadata(self):
        """
        Combine info_label.

        Sets ``self.info_label`` (set of unique labels) and
        ``self.joint_metadata`` (a combined string). With several labels,
        phrases and words shared by all labels are written once, after the
        parts that differ.
        """
        labels = self._unique_in_order(m_b.info_label for m_b in self.m_bundles)
        self.info_label = set(labels)
        # Find a pleasing combination of the info_label.
        if len(labels) == 1:
            self.joint_metadata = labels[0]
            return

        order = ["u", "g", "r", "i", "z", "y"]
        # Split each info_label into separate phrases
        # (filter / proposal / constraint..), stripped of white space.
        splitmetas = []
        for label in labels:
            if " and " in label:
                phrases = label.split(" and ")
            elif ", " in label:
                phrases = label.split(", ")
            else:
                phrases = [label]
            splitmetas.append({p.strip() for p in phrases})
        # Look for common phrases and separate them from the rest.
        common = set.intersection(*splitmetas)
        diff = [x.difference(common) for x in splitmetas]
        # Now look within the 'diff' phrases and see if there are any
        # common words to split off. Only one phrase per label is examined;
        # sorted() makes the choice reproducible.
        diffsplit = [set(sorted(d)[0].split()) if d else set() for d in diff]
        diffcommon = set.intersection(*diffsplit)
        diffdiff = [x.difference(diffcommon) for x in diffsplit]
        if min(len(x) for x in diffdiff) == 0:
            # Some label has nothing left after removing the common words:
            # do not subdivide, just list one whole phrase per label in
            # order of length (so it goes 'g', 'g dithered', etc.).
            # Labels with no differing phrase at all are skipped
            # (indexing into their empty set would raise IndexError).
            whole = [sorted(d)[0] for d in diff if d]
            diffdiff = sorted(whole, key=len)
            diffcommon = []
        else:
            # diffdiff is the part where we might expect our filter values
            # to appear; try to put this in order.
            ordered = []
            for f in order:
                ordered.extend(d for d in diffdiff if d == {f})
            remaining = [d for d in diffdiff if d not in ordered]
            diffdiff = [" ".join(sorted(c)) for c in ordered + remaining]
        # And put it all back together.
        self.joint_metadata = (
            ", ".join(diffdiff) + " " + " ".join(sorted(diffcommon)) + " " + " ".join(sorted(common))
        )

    def _combine_constraints(self):
        """
        Combine the constraints.

        Sets ``self.constraints`` to the unique non-None constraints,
        joined with '; ' in bundle order.
        """
        constraints = self._unique_in_order(
            m_b.constraint for m_b in self.m_bundles if m_b.constraint is not None
        )
        self.constraints = "; ".join(constraints)

    # ------------------------------------------------------------------
    # Auto-generated plot dictionary items.
    # ------------------------------------------------------------------
    def _build_title(self):
        """
        Build a plot title from the metric names, runNames and info_label.

        Returns
        -------
        plot_title : `str`
        """
        # Create a plot title from the parts of the
        # metric/run_name/info_label that are shared by every bundle.
        plot_title = ""
        if len(self.run_names) == 1:
            plot_title += next(iter(self.run_names))
        if len(self.info_label) == 1:
            plot_title += " " + next(iter(self.info_label))
        if len(self.metric_names) == 1:
            plot_title += ": " + next(iter(self.metric_names))
        if plot_title == "":
            # If there were more than one of everything above,
            # use joint info_label and metricNames.
            plot_title = self.joint_metadata + " " + self.joint_metric_names
        return plot_title

    def _build_x_ylabels(self, plot_func, len_max=25):
        """
        Build a plot x and y label.

        Parameters
        ----------
        plot_func : plotter object
            Its ``plot_type`` selects how the labels are built.
        len_max : `int`, optional
            If the xlabel starts longer than this, add the units as a newline.

        Returns
        -------
        xlabel, ylabel : `str` or None
        """
        single = len(self.m_bundles) == 1
        if plot_func.plot_type == "BinnedData":
            if single:
                m_b = self.m_bundles[0]
                separator = " (" if len(m_b.slicer.slice_col_name) < len_max else " \n("
                xlabel = m_b.slicer.slice_col_name + separator + m_b.slicer.slice_col_units + ")"
                ylabel = m_b.metric.name + " (" + m_b.metric.units + ")"
            else:
                columns = self._unique_in_order(m_b.slicer.slice_col_name for m_b in self.m_bundles)
                xlabel = ", ".join(columns)
                ylabel = self.joint_metric_names
        elif plot_func.plot_type == "MetricVsH":
            if single:
                m_b = self.m_bundles[0]
                ylabel = m_b.metric.name + " (" + m_b.metric.units + ")"
            else:
                ylabel = self.joint_metric_names
            xlabel = "H (mag)"
        else:
            if single:
                m_b = self.m_bundles[0]
                xlabel = m_b.metric.name
                units = m_b.metric.units
                # Only add the units if there are any.
                if units is not None and len(units) > 0:
                    if len(xlabel) < len_max:
                        xlabel += " (" + units + ")"
                    else:
                        xlabel += "\n(" + units + ")"
                ylabel = None
            else:
                xlabel = self.joint_metric_names
                # Use a ylabel only if the bundles agree on exactly one.
                ylabels = {m_b.plot_dict["ylabel"] for m_b in self.m_bundles if "ylabel" in m_b.plot_dict}
                ylabel = next(iter(ylabels)) if len(ylabels) == 1 else None
        return xlabel, ylabel

    def _build_legend_labels(self):
        """
        Build a set of legend labels,
        using the parts of the run_name/info_label/metricNames that change.

        Returns
        -------
        labels : `list`
            One label per bundle (``[None]`` for a single bundle). A
            "label" set in a bundle's plot_dict is used as-is.
        """
        if len(self.m_bundles) == 1:
            return [None]
        labels = []
        for m_b in self.m_bundles:
            if "label" in m_b.plot_dict:
                label = m_b.plot_dict["label"]
            else:
                label = ""
                if len(self.run_names) > 1:
                    label += m_b.run_name
                if len(self.info_label) > 1:
                    label += " " + m_b.info_label
                if len(self.metric_names) > 1:
                    label += " " + m_b.metric.name
            labels.append(label)
        return labels

    def _build_colors(self):
        """
        Try to set an appropriate range of colors for the metric Bundles.

        Returns
        -------
        colors : `list`
            One color per bundle: the bundle's own plot_dict "color" if
            set, else a color looked up from a quoted single-character
            filter value in its constraint, else "b". If several bundles
            all end up with the same color, random RGB colors are
            returned instead.
        """
        if len(self.m_bundles) == 1:
            return [self.m_bundles[0].plot_dict.get("color", "b")]
        colors = []
        for m_b in self.m_bundles:
            color = "b"
            if "color" in m_b.plot_dict:
                color = m_b.plot_dict["color"]
            elif m_b.constraint is not None and "filter" in m_b.constraint:
                # If the filter is part of the sql constraint, we'll try to
                # use that. Accept either quote character around the value.
                for v in m_b.constraint.replace("'", '"').split('"'):
                    # Guess that a single quoted character is the filter.
                    if len(v) == 1 and v in self.filtercolors:
                        color = self.filtercolors[v]
            colors.append(color)
        # If we happened to end up with the same color throughout
        # (say, the metrics were all in the same filter)
        # then go ahead and generate random colors.
        if self._n_distinct(colors) == 1:
            colors = [np.random.rand(3) for _ in self.m_bundles]
        return colors

    def _build_cbar_format(self):
        """
        Set the color bar format.

        Returns
        -------
        cbar_format : `str` or None
            "%d" if every bundle's metric dtype is "int", else None.
        """
        metric_dtypes = {m_b.metric.metric_dtype for m_b in self.m_bundles}
        return "%d" if metric_dtypes == {"int"} else None

    def _build_file_root(self, outfile_suffix=None):
        """
        Build a root filename for plot outputs.
        If there is only one metricBundle,
        this is equal to the metricBundle fileRoot + outfile_suffix.
        For multiple metricBundles,
        this is created from the runNames, info_label and metric names.
        If you do not wish to use the automatic filenames,
        then you could set 'savefig' to False and
        save the file manually to disk,
        using the plot figure numbers returned by 'plot'.

        Parameters
        ----------
        outfile_suffix : `str`, optional
            Suffix appended (after an underscore) to the filename root.

        Returns
        -------
        outfile : `str`
            The name, passed through ``utils.name_sanitize``.
        """
        if len(self.m_bundles) == 1:
            outfile = self.m_bundles[0].file_root
        else:
            outfile = "_".join([self.joint_run_names, self.joint_metric_names, self.joint_metadata])
            outfile += "_" + self.m_bundles[0].slicer.slicer_name[:4].upper()
        if outfile_suffix is not None:
            outfile += "_" + outfile_suffix
        return utils.name_sanitize(outfile)

    def _build_display_dict(self):
        """
        Generate a display dictionary.
        This is most useful for when there are many metricBundles
        being combined into a single plot.

        Returns
        -------
        display_dict : `dict`
            The bundle's own display_dict for a single bundle, otherwise a
            new dictionary with "order", "group", "subgroup" and "caption".
        """
        if len(self.m_bundles) == 1:
            return self.m_bundles[0].display_dict
        display_dict = {}
        group = set()
        subgroup = set()
        order = 0
        for m_b in self.m_bundles:
            group.add(m_b.display_dict["group"])
            subgroup.add(m_b.display_dict["subgroup"])
            if order < m_b.display_dict["order"]:
                order = m_b.display_dict["order"] + 1
        display_dict["order"] = order
        # Bundles from different (sub)groups are filed under 'Comparisons'.
        display_dict["group"] = "Comparisons" if len(group) > 1 else list(group)[0]
        display_dict["subgroup"] = "Comparisons" if len(subgroup) > 1 else list(subgroup)[0]
        display_dict["caption"] = (
            "%s metric(s) calculated on a %s grid, for opsim runs %s, for info_label values of %s."
            % (
                self.joint_metric_names,
                self.m_bundles[0].slicer.slicer_name,
                self.joint_run_names,
                self.joint_metadata,
            )
        )
        return display_dict

    def _check_plot_dicts(self):
        """
        Check to make sure there are no conflicts in the plot_dicts
        that are being used in the same subplot.

        Conflicting "color_max" / "color_min" values are replaced in every
        plot_dict by their maximum / minimum, a conflicting "title" by the
        auto-generated title, and other conflicting keys are reset to None
        with a warning. None values are not treated as conflicting.

        Raises
        ------
        ValueError
            If the number of plot_dicts differs from the number of bundles.
        """
        # Check that the length is OK
        if len(self.plot_dicts) != len(self.m_bundles):
            raise ValueError(
                "plot_dicts (%i) must be same length as mBundles (%i)"
                % (len(self.plot_dicts), len(self.m_bundles))
            )
        # These are the keys that need to match (or be None)
        keys2_check = ["xlim", "ylim", "color_min", "color_max", "title"]
        # Identify how many subplots there are.
        # If there are more than one, just don't change anything.
        # This assumes that if there are more than one,
        # the plot_dicts are actually all compatible.
        subplots = {pd["subplot"] for pd in self.plot_dicts if "subplot" in pd}
        if len(subplots) > 1:
            return
        reset_keys = []
        for key in keys2_check:
            values = [pd[key] for pd in self.plot_dicts if key in pd]
            if self._n_distinct(values, ignore_none=True) <= 1:
                continue
            # We will reset some of the keys to the default,
            # but for some we should do better.
            if key.endswith("_max"):
                merged = np.max([v for v in values if v is not None])
                for pd in self.plot_dicts:
                    pd[key] = merged
            elif key.endswith("_min"):
                merged = np.min([v for v in values if v is not None])
                for pd in self.plot_dicts:
                    pd[key] = merged
            elif key == "title":
                title = self._build_title()
                for pd in self.plot_dicts:
                    pd["title"] = title
            else:
                warnings.warn(
                    'Found more than one value to be set for "%s" in the plot_dicts.' % (key)
                    + " Will reset to default value. (found values %s)" % values
                )
                reset_keys.append(key)
        # Reset the most of the keys to defaults;
        # this can generally be done safely.
        for key in reset_keys:
            for pd in self.plot_dicts:
                pd[key] = None

    # ------------------------------------------------------------------
    # Plotting and saving.
    # ------------------------------------------------------------------
    def plot(
        self,
        plot_func,
        plot_dicts=None,
        display_dict=None,
        outfile_root=None,
        outfile_suffix=None,
    ):
        """
        Create a plot for the active metric bundles (self.set_metric_bundles).

        Parameters
        ----------
        plot_func : `rubin_sim.plots.BasePlotter`
            The plotter to use to make the figure.
        plot_dicts : `list` of [`dict`], optional
            List of plot_dicts for each metric bundle.
            Can use these to override individual metric bundle colors, etc.
        display_dict : `dict`, optional
            Information to save to resultsDb to accompany the figure on the
            show_maf pages. Generally set automatically. Includes a caption.
        outfile_root : `str`, optional
            Output filename. Generally set automatically, but can be
            overriden (such as when output filenames get too long).
        outfile_suffix : `str`, optional
            A suffix to add to the end of the default output filename.
            Useful when creating a series of plots, such as for a movie.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The plot. None if no bundle could be plotted.
        """
        # Plotters that do not declare 'object_plotter' are treated as
        # not supporting object data.
        if not getattr(plot_func, "object_plotter", False):
            # Check that metric_values type and plotter are compatible
            # (most are float/float, but some plotters expect object data ..
            # and some only do sometimes).
            for m_b in self.m_bundles:
                if m_b.metric.metric_dtype == "object" and not m_b.plot_dict.get("metric_is_color", False):
                    warnings.warn("Cannot plot object metric values with this plotter.")
                    return None
        # Update x/y labels using plot_type.
        self.set_plot_dicts(plot_dicts=plot_dicts, plot_func=plot_func, reset=False)
        # Set outfile name.
        outfile = self._build_file_root(outfile_suffix) if outfile_root is None else outfile_root
        plot_type = plot_func.plot_type
        if len(self.m_bundles) > 1:
            plot_type = "Combo" + plot_type
        # Make plot.
        fig = None
        for m_b, plot_dict in zip(self.m_bundles, self.plot_dicts):
            if m_b.metric_values is None:
                # Skip this metricBundle.
                msg = 'MetricBundle (%s) has no "metric_values".' % (m_b.file_root)
                msg += " Either the values have not been calculated or they have been deleted."
                warnings.warn(msg)
            elif np.size(np.where(~m_b.metric_values.mask)) == 0:
                msg = "MetricBundle (%s) has no unmasked metric_values, skipping plots." % (m_b.file_root)
                warnings.warn(msg)
            else:
                fig = plot_func(m_b.metric_values, m_b.slicer, plot_dict, fig=fig)
        # Decorate only if something was actually drawn; otherwise
        # activating 'fig' would just open a new, empty figure.
        if fig is not None:
            first_dict = self.plot_dicts[0]
            # Add a legend if more than one metricValue is being plotted
            # or if legend_loc is specified.
            default_loc = "upper right" if len(self.m_bundles) > 1 else None
            legend_loc = first_dict.get("legend_loc", default_loc)
            if legend_loc is not None or "suptitle" in first_dict:
                # Activate expected figure before using the pyplot calls.
                plt.figure(fig)
            if legend_loc is not None:
                plt.legend(loc=legend_loc, fancybox=True, fontsize="smaller")
            # Add the super title if provided.
            if "suptitle" in first_dict:
                plt.suptitle(first_dict["suptitle"])
        # Save to disk and file info to results_db if desired.
        if self.savefig:
            if display_dict is None:
                display_dict = self._build_display_dict()
            self.save_fig(
                fig,
                outfile,
                plot_type,
                self.joint_metric_names,
                self.slicer.slicer_name,
                self.joint_run_names,
                self.constraints,
                self.joint_metadata,
                display_dict,
            )
        return fig

    def save_fig(
        self,
        fig,
        outfile_root,
        plot_type,
        metric_name,
        slicer_name,
        run_name,
        constraint,
        info_label,
        display_dict=None,
    ):
        """Save a figure (and optional thumbnail) and record it in results_db.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or None
            The figure to save. If None, a warning is issued and nothing
            is written.
        outfile_root : `str`
            Root of the output filename; ``_<plot_type>.<fig_format>`` is
            appended.
        plot_type : `str`
            The plot type, used in the filename and recorded in results_db.
        metric_name, slicer_name, run_name, constraint, info_label : `str`
            Passed to ``results_db.update_metric`` to identify the metric.
        display_dict : `dict`, optional
            Passed to ``results_db.update_display``; defaults to ``{}``.
        """
        plot_file = outfile_root + "_" + plot_type + "." + self.fig_format
        if fig is None:
            warnings.warn(f"Trying to save figure to {plot_file} but figure is None. Skipping.")
            return
        savefig_kwargs = {"dpi": self.dpi, "format": self.fig_format}
        if self.trim_whitespace:
            savefig_kwargs["bbox_inches"] = "tight"
        fig.savefig(os.path.join(self.out_dir, plot_file), **savefig_kwargs)
        # Generate a png thumbnail.
        if self.thumbnail:
            thumb_file = "thumb." + outfile_root + "_" + plot_type + ".png"
            fig.savefig(os.path.join(self.out_dir, thumb_file), dpi=72, bbox_inches="tight")
        # Save information about the file to results_db.
        if self.results_db:
            if display_dict is None:
                display_dict = {}
            metric_id = self.results_db.update_metric(
                metric_name, slicer_name, run_name, constraint, info_label, None
            )
            self.results_db.update_display(metric_id=metric_id, display_dict=display_dict, overwrite=False)
            self.results_db.update_plot(metric_id=metric_id, plot_type=plot_type, plot_file=plot_file)