# Public API of this module: the names exported by a star import.
__all__ = (
    "HistogramMetric", "AccumulateMetric", "AccumulateCountMetric",
    "HistogramM5Metric", "AccumulateM5Metric", "AccumulateUniformityMetric",
)
import numpy as np
from scipy import stats
from .base_metric import BaseMetric
class VectorMetric(BaseMetric):
    """Base class for metrics that produce one value per bin (a vector).

    The bin column is always appended to the list of columns handed to
    ``BaseMetric``; subclasses read it from the data slice in ``run``.

    Parameters
    ----------
    bins : array_like or None
        Bin edges; the metric's ``shape`` is ``np.size(bins) - 1``.
    bin_col : str
        Name of the column used to bin the data.
    col : str or sequence of str
        Data column(s) the metric needs in addition to ``bin_col``.
    units : str or None
        Units of the metric values.
    metric_dtype : type
        dtype of the metric values.
    **kwargs
        Passed through to ``BaseMetric``.
    """

    def __init__(self, bins=None, bin_col="night", col="night", units=None, metric_dtype=float, **kwargs):
        # A lone string is one column name; anything else is an iterable of names.
        required = [col] if isinstance(col, str) else list(col)
        required.append(bin_col)
        super().__init__(col=required, units=units, metric_dtype=metric_dtype, **kwargs)
        self.bins = bins
        self.bin_col = bin_col
        # N edges delimit N - 1 bins, hence one output element per bin.
        self.shape = np.size(bins) - 1
class HistogramMetric(VectorMetric):
    """Reduce one column within bins of another (wrapper for ``stats.binned_statistic``).

    Rows of the data slice are assigned to ``bins`` according to ``bin_col``
    and the values of ``col`` falling in each bin are reduced with
    ``statistic``.

    Parameters
    ----------
    bins : array_like
        Bin edges forwarded to ``scipy.stats.binned_statistic``.
    bin_col : str
        Column used to assign rows to bins (default ``"night"``).
    col : str
        Column whose values are reduced within each bin.
    units : str
        Units of the metric values (default ``"Count"``).
    statistic : str or callable
        Per-bin reduction forwarded to ``binned_statistic`` (default ``"count"``).
    metric_dtype : type
        dtype of the metric values.
    **kwargs
        Passed through to ``VectorMetric``.
    """

    def __init__(
        self,
        bins=None,
        bin_col="night",
        col="night",
        units="Count",
        statistic="count",
        metric_dtype=float,
        **kwargs,
    ):
        self.statistic = statistic
        self.col = col
        super().__init__(col=col, bins=bins, bin_col=bin_col, units=units, metric_dtype=metric_dtype, **kwargs)

    def run(self, data_slice, slice_point=None):
        """Return the per-bin ``statistic`` of ``col`` for ``data_slice``.

        ``data_slice`` (assumed to be a NumPy structured array) is sorted in
        place by ``bin_col``. ``binned_statistic`` does not need sorted input;
        the sort is kept so the caller-visible side effect is unchanged.
        """
        data_slice.sort(order=self.bin_col)
        result, _edges, _bin_numbers = stats.binned_statistic(
            data_slice[self.bin_col],
            data_slice[self.col],
            bins=self.bins,
            statistic=self.statistic,
        )
        return result
 
class AccumulateMetric(VectorMetric):
    """Running (accumulated) value of a column, sampled at each upper bin edge.

    The slice is ordered by ``bin_col`` and ``function.accumulate`` is applied
    to ``col``. For every upper edge ``e`` in ``bins[1:]`` the output is the
    accumulated value over all rows with ``bin_col <= e``. Edges that no row
    reaches (and every edge of an empty slice) are set to ``badval``.

    Parameters
    ----------
    col : str
        Column to accumulate.
    bins : array_like
        Bin edges; one output value is produced per upper edge (``bins[1:]``).
    bin_col : str
        Column used to order rows and compare against the edges.
    function : numpy.ufunc
        Ufunc whose ``accumulate`` method is used (default ``np.add``,
        i.e. a cumulative sum).
    metric_dtype : type
        dtype of the metric values.
    **kwargs
        Passed through to ``VectorMetric``.
    """

    def __init__(self, col="night", bins=None, bin_col="night", function=np.add, metric_dtype=float, **kwargs):
        self.function = function
        super().__init__(col=col, bin_col=bin_col, bins=bins, metric_dtype=metric_dtype, **kwargs)
        self.col = col

    def run(self, data_slice, slice_point=None):
        """Return the accumulated ``col`` value at each upper bin edge.

        ``data_slice`` is sorted in place by ``bin_col`` (required by
        ``searchsorted`` below).
        """
        data_slice.sort(order=self.bin_col)
        running = self.function.accumulate(data_slice[self.col])
        # n_upto[i] = number of rows with bin_col <= bins[i + 1].
        n_upto = np.searchsorted(data_slice[self.bin_col], self.bins[1:], side="right")
        if running.size == 0:
            # No rows at all: no edge is reached.
            return np.full(n_upto.shape, self.badval, dtype=running.dtype)
        # running[k] covers rows 0..k (k + 1 rows), so the value for n rows sits
        # at index n - 1. Edges with n == 0 get a placeholder index, then badval.
        pick = n_upto - 1
        pick[pick < 0] = 0
        result = running[pick]
        result[n_upto == 0] = self.badval
        return result
 
class AccumulateCountMetric(AccumulateMetric):
    """Running count of rows, sampled at each upper bin edge.

    Every row contributes 1 to ``function.accumulate``. With the default
    ``function=np.add``, output ``i`` is the number of rows with
    ``bin_col <= bins[i + 1]``. Edges that no row reaches (and every edge of
    an empty slice) are set to ``badval``.
    """

    def run(self, data_slice, slice_point=None):
        """Return the accumulated row count at each upper bin edge.

        ``data_slice`` is sorted in place by ``bin_col`` (required by
        ``searchsorted`` below).
        """
        data_slice.sort(order=self.bin_col)
        ones = np.ones(data_slice.size, dtype=int)
        running = self.function.accumulate(ones)
        # n_upto[i] = number of rows with bin_col <= bins[i + 1].
        n_upto = np.searchsorted(data_slice[self.bin_col], self.bins[1:], side="right")
        if running.size == 0:
            # No rows at all: no edge is reached.
            return np.full(n_upto.shape, self.badval, dtype=running.dtype)
        # running[k] covers rows 0..k (k + 1 rows), so the value for n rows sits
        # at index n - 1. Edges with n == 0 get a placeholder index, then badval.
        pick = n_upto - 1
        pick[pick < 0] = 0
        result = running[pick]
        result[n_upto == 0] = self.badval
        return result
 
class HistogramM5Metric(HistogramMetric):
    """Calculate the coadded depth for each bin (e.g., per night).

    Each row's ``m5_col`` value is converted to ``10**(0.8 * m5)``. These are
    summed within each bin of ``bin_col`` and converted back with
    ``1.25 * log10``. Bins whose summed flux is zero (empty bins) are set to
    ``badval``.

    Parameters
    ----------
    bins : array_like
        Bin edges forwarded to ``scipy.stats.binned_statistic``.
    bin_col : str
        Column used to assign rows to bins (default ``"night"``).
    m5_col : str
        Depth column to coadd (default ``"fiveSigmaDepth"``).
    units : str
        Units of the metric values (default ``"mag"``).
    metric_name : str
        Name of the metric.
    **kwargs
        Passed through to ``HistogramMetric``.
    """

    def __init__(
        self,
        bins=None,
        bin_col="night",
        m5_col="fiveSigmaDepth",
        units="mag",
        metric_name="HistogramM5Metric",
        **kwargs,
    ):
        super().__init__(col=m5_col, bin_col=bin_col, bins=bins, metric_name=metric_name, units=units, **kwargs)
        self.m5_col = m5_col

    def run(self, data_slice, slice_point=None):
        """Return the coadded depth per bin, with ``badval`` for empty bins.

        ``data_slice`` is sorted in place by ``bin_col``.
        """
        data_slice.sort(order=self.bin_col)
        flux = 10.0 ** (0.8 * data_slice[self.m5_col])
        flux_sum, _edges, _bin_numbers = stats.binned_statistic(
            data_slice[self.bin_col], flux, bins=self.bins, statistic="sum"
        )
        empty = flux_sum == 0.0
        # log10(0) would emit a divide-by-zero RuntimeWarning. Those entries are
        # overwritten with badval immediately, so the warning is suppressed.
        with np.errstate(divide="ignore"):
            result = 1.25 * np.log10(flux_sum)
        result[empty] = self.badval
        return result
 
class AccumulateM5Metric(AccumulateMetric):
    """Coadded depth accumulated up to each upper bin edge.

    Each row's ``m5_col`` value is converted to ``10**(0.8 * m5)``. These are
    summed cumulatively in ``bin_col`` order, and for every upper edge in
    ``bins[1:]`` the running sum over rows with ``bin_col <= edge`` is
    converted back with ``1.25 * log10``. Edges that no row reaches (and every
    edge of an empty slice) are set to ``badval``.

    Parameters
    ----------
    bins : array_like
        Bin edges; one output value is produced per upper edge.
    bin_col : str
        Column used to order rows and compare against the edges.
    m5_col : str
        Depth column to coadd (default ``"fiveSigmaDepth"``).
    metric_name : str
        Name of the metric.
    **kwargs
        Passed through to ``AccumulateMetric``.
    """

    def __init__(
        self, bins=None, bin_col="night", m5_col="fiveSigmaDepth", metric_name="AccumulateM5Metric", **kwargs
    ):
        self.m5_col = m5_col
        super().__init__(bins=bins, bin_col=bin_col, col=m5_col, metric_name=metric_name, **kwargs)

    def run(self, data_slice, slice_point=None):
        """Return the accumulated coadded depth at each upper bin edge.

        ``data_slice`` is sorted in place by ``bin_col`` (required by
        ``searchsorted`` below).
        """
        data_slice.sort(order=self.bin_col)
        flux = 10.0 ** (0.8 * data_slice[self.m5_col])
        running = np.add.accumulate(flux)
        # n_upto[i] = number of rows with bin_col <= bins[i + 1].
        n_upto = np.searchsorted(data_slice[self.bin_col], self.bins[1:], side="right")
        if running.size == 0:
            # No rows at all: no edge is reached.
            return np.full(n_upto.shape, self.badval, dtype=float)
        # running[k] covers rows 0..k (k + 1 rows), so the value for n rows sits
        # at index n - 1. Edges with n == 0 get a placeholder index, then badval.
        pick = n_upto - 1
        pick[pick < 0] = 0
        result = 1.25 * np.log10(running[pick])
        result[n_upto == 0] = self.badval
        return result