# Source code for polymath.extensions.mask_ops

#########################################################################################
# polymath/extensions/mask_ops.py: masking operations
#########################################################################################

import numpy as np
from polymath.qube import Qube


def mask_where(self, mask, replace=None, *, remask=True, recursive=True):
    """Return a copy of this object after a mask has been applied.

    This object itself (not a copy) is returned when there is nothing to do: either no
    replacement was given and `remask` is False, or the mask selects no elements.

    Parameters:
        mask (array-like): The mask to apply as a boolean array.
        replace (object, optional): A single replacement value or an object of the same
            shape and class as this object, containing replacement values. These are
            inserted into returned object at every masked location. Use None to leave
            values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.
        recursive (bool, optional): True to mask the derivatives as well;
            False to leave them unmasked. This is honored on every code path, including
            the one where no replacement value is given.

    Returns:
        Qube: A copy of this object with the mask applied.

    Raises:
        ValueError: If the replacement shape is incompatible with the object shape.
    """

    if replace is None and not remask:      # nothing to do
        return self

    # Convert to boolean array if necessary
    mask = Qube._suitable_mask(mask, self._shape)

    # If the mask is empty, return the object as is
    if not np.any(mask):
        return self

    # Get the replacement value as this type; it must be shapeless or match our shape
    if replace is not None:
        replace = self.as_this_type(replace, recursive=True)
        if replace._shape not in ((), self._shape):
            raise ValueError(f'{type(self).__name__}.mask_where() replacement has '
                             f'incompatible shape: {replace._shape}, {self._shape}')

    # Shapeless case: the single value is masked, so the result is either a copy of this
    # object or a copy of the replacement, optionally flagged as entirely masked
    if self._is_scalar:
        if replace is None:
            obj = self.copy(recursive=True)
        else:
            obj = replace.copy(recursive=True)

        if remask:
            obj = obj.remask(True, recursive=recursive)

        return obj

    # Case with no replacement. We can only get here with remask=True, because the
    # combination (replace=None, remask=False) returned above.
    if replace is None:
        # Note that the new mask must be a copy
        # The caller's `recursive` choice is passed through, matching the other branches
        return self.remask_or(mask, recursive=recursive)

    # If replacement is an array or single Qube...

    # We need a mask to apply to the given replacement value.
    # If the replacement value has shape, use the existing mask; otherwise,
    # use True, which will allow the replacement to broadcast as needed.
    rep_mask = mask if replace._shape else True

    obj = self.copy()
    obj[mask] = replace[rep_mask]   # handles derivatives too!

    if remask:
        obj = obj.remask_or(mask, recursive=recursive)

    return obj


def mask_where_eq(self, match, replace=None, *, remask=True):
    """Return a copy of this object with items equal to a value masked.

    An item counts as equal only when every one of its components equals the
    corresponding component of `match`. The masked items can be replaced instead of, or
    in addition to, being masked. The work is delegated to mask_where().

    Parameters:
        match (object): The item value to match; it is converted to this object's type.
        replace (object, optional): A single replacement value or an object of the same
            shape and class as this object, containing replacement values. These are
            inserted into returned object at every masked location. Use None to leave
            values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy of this object with matching items masked.
    """

    target = self.as_this_type(match, recursive=False)

    # The trailing `_rank` axes hold the components of each item: (-rank, ..., -1)
    item_axes = tuple(-k for k in range(self._rank, 0, -1))

    # Compare component by component, then require agreement across the whole item
    componentwise = self._values == target._values
    selected = np.all(componentwise, axis=item_axes)

    return self.mask_where(selected, replace=replace, remask=remask)


def mask_where_ne(self, match, replace=None, *, remask=True):
    """Return a copy of this object with items not equal to a value masked.

    An item is "not equal" when at least one of its components differs from the
    corresponding component of `match`, so this selects exactly the items that
    mask_where_eq() would leave alone.

    Instead of or in addition to masking the items, the values can be replaced.
    If no items need to be masked, this object is returned unchanged.

    Parameters:
        match (object): The item value to match.
        replace (object, optional): A single replacement value or an object of the same
            shape and class as this object, containing replacement values. These are
            inserted into returned object at every masked location. Use None to leave
            values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy of this object with non-matching items masked.
    """

    match = self.as_this_type(match, recursive=False)

    # Reduce over the trailing item axes. A single differing component is enough to make
    # the whole item unequal, hence any() rather than all().
    axes = tuple(range(-self._rank, 0))
    mask = np.any(self._values != match._values, axis=axes)

    return self.mask_where(mask, replace=replace, remask=remask)


def mask_where_le(self, limit, replace=None, *, remask=True):
    """Return a copy of this object with items <= a limit value masked.

    The selected items can be replaced instead of, or in addition to, being masked.
    Wherever the limit itself is masked it is treated as -infinity. The work is
    delegated to mask_where().

    Parameters:
        limit (object): The limiting value.
        replace (object, optional): A single replacement value or an object of the same
            shape and class as this object, containing replacement values. These are
            inserted into returned object at every masked location. Use None to leave
            values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy of this object with items <= limit masked.

    Raises:
        ValueError: If this object has denominators or item rank > 0.
    """

    opname = 'mask_where_le()'
    self._disallow_denom(opname)
    self._require_scalar(opname)

    bound = _limit_from_qube(self, limit, masked=-np.inf, op=opname)
    at_or_below = self._values <= bound

    return self.mask_where(at_or_below, replace=replace, remask=remask)


def mask_where_ge(self, limit, replace=None, *, remask=True):
    """Return a copy of this object with items >= a limit value masked.

    The selected items can be replaced instead of, or in addition to, being masked.
    Wherever the limit itself is masked it is treated as +infinity. The work is
    delegated to mask_where().

    Parameters:
        limit (object): The limiting value.
        replace (object, optional): A single replacement value or an object of the same
            shape and class as this object, containing replacement values. These are
            inserted into returned object at every masked location. Use None to leave
            values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy of this object with items >= limit masked.

    Raises:
        ValueError: If this object has denominators or item rank > 0.
    """

    opname = 'mask_where_ge()'
    self._disallow_denom(opname)
    self._require_scalar(opname)

    bound = _limit_from_qube(self, limit, masked=np.inf, op=opname)
    at_or_above = self._values >= bound

    return self.mask_where(at_or_above, replace=replace, remask=remask)


def mask_where_lt(self, limit, replace=None, *, remask=True):
    """Return a copy with items less than a limit value masked.

    The selected items can be replaced instead of, or in addition to, being masked.
    Wherever the limit itself is masked it is treated as -infinity, so nothing is
    selected there. The work is delegated to mask_where().

    Parameters:
        limit (object): The limiting value.
        replace (object, optional): A single replacement value or an object of the same
            shape and class as this object, containing replacement values. These are
            inserted into returned object at every masked location. Use None to leave
            values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy of this object with items < limit masked.

    Raises:
        ValueError: If this object has denominators or item rank > 0.
    """

    opname = 'mask_where_lt()'
    self._disallow_denom(opname)
    self._require_scalar(opname)

    bound = _limit_from_qube(self, limit, masked=-np.inf, op=opname)
    strictly_below = self._values < bound

    return self.mask_where(strictly_below, replace=replace, remask=remask)


def mask_where_gt(self, limit, replace=None, *, remask=True):
    """Return a copy with items greater than a limit value masked.

    The selected items can be replaced instead of, or in addition to, being masked.
    Wherever the limit itself is masked it is treated as +infinity, so nothing is
    selected there. The work is delegated to mask_where().

    Parameters:
        limit (object): The limiting value.
        replace (object, optional): A single replacement value or an object of the same
            shape and class as this object, containing replacement values. These are
            inserted into returned object at every masked location. Use None to leave
            values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy of this object with items > limit masked.

    Raises:
        ValueError: If this object has denominators or item rank > 0.
    """

    opname = 'mask_where_gt()'
    self._disallow_denom(opname)
    self._require_scalar(opname)

    bound = _limit_from_qube(self, limit, masked=np.inf, op=opname)
    strictly_above = self._values > bound

    return self.mask_where(strictly_above, replace=replace, remask=remask)


def mask_where_between(self, lower, upper, *, mask_endpoints=False, replace=None,
                       remask=True):
    """Return a copy with values between two limits masked.

    Instead of or in addition to masking the items, the values can be replaced. If no
    items need to be masked, this object is returned unchanged.

    Parameters:
        lower (Qube, array-like, float, or int): The lower limit as a number or an object
            that can be broadcasted to the shape of this object's values (including its
            item shape). Where a Qube limit is masked it is replaced by +infinity, so
            nothing is selected at that location.
        upper (Qube, array-like, float, or int): The upper limit as a number or an object
            that can be broadcasted to the shape of this object's values (including its
            item shape). Where a Qube limit is masked it is replaced by -infinity, so
            nothing is selected at that location.
        mask_endpoints (bool or tuple, optional): True to mask the endpoints, where values
            are equal to the lower or upper limits; False to exclude the endpoints. Use a
            tuple of two values (lower, upper) to handle the endpoints differently.
        replace (Qube, array-like, float, or int, optional): A single replacement value or
            an object that can be broadcasted to the shape of this object's values
            (including its item shape). Masked replacements become masked. Use None to
            leave values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy with values between the specified limits masked.

    Raises:
        ValueError: If this object has denominators or item rank > 0.
    """

    self._disallow_denom('mask_where_between()')
    self._require_scalar('mask_where_between()')
    lower = _limit_from_qube(self, lower, masked=+np.inf, op='mask_where_between()')
    upper = _limit_from_qube(self, upper, masked=-np.inf, op='mask_where_between()')

    # A single flag applies to both endpoints
    if not isinstance(mask_endpoints, (tuple, list)):
        mask_endpoints = (mask_endpoints, mask_endpoints)

    # Use the comparison operators rather than calling the bound dunder methods directly.
    # For Python-scalar values a direct call such as (3).__gt__(2.5) returns
    # NotImplemented instead of falling back to the reflected comparison, and the "&"
    # below would then raise TypeError. Exactly one comparison per limit is evaluated.
    values = self._values

    if mask_endpoints[0]:               # lower point included in the mask
        above_lower = values >= lower
    else:                               # lower point excluded from the mask
        above_lower = values > lower

    if mask_endpoints[1]:               # upper point included in the mask
        below_upper = values <= upper
    else:                               # upper point excluded from the mask
        below_upper = values < upper

    mask = above_lower & below_upper

    return self.mask_where(mask, replace=replace, remask=remask)


def mask_where_outside(self, lower, upper, *, mask_endpoints=False, replace=None,
                       remask=True):
    """Return a copy with values outside two limits masked.

    Instead of or in addition to masking the items, the values can be replaced. If no
    items need to be masked, this object is returned unchanged.

    Parameters:
        lower (Qube, array-like, float, or int): The lower limit as a number or an object
            that can be broadcasted to the shape of this object's values (including its
            item shape). Where a Qube limit is masked it is replaced by -infinity, so
            that limit selects nothing at that location.
        upper (Qube, array-like, float, or int): The upper limit as a number or an object
            that can be broadcasted to the shape of this object's values (including its
            item shape). Where a Qube limit is masked it is replaced by +infinity, so
            that limit selects nothing at that location.
        mask_endpoints (bool or tuple, optional): True to mask the endpoints, where values
            are equal to the lower or upper limits; False to exclude the endpoints. Use a
            tuple of two values (lower, upper) to handle the endpoints differently.
        replace (Qube, array-like, float, or int, optional): A single replacement value or
            an object that can be broadcasted to the shape of this object's values
            (including its item shape). Masked replacements become masked. Use None to
            leave values unchanged.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.

    Returns:
        Qube: A copy with values outside the specified limits masked.

    Raises:
        ValueError: If this object has denominators or item rank > 0.
    """

    self._disallow_denom('mask_where_outside()')
    self._require_scalar('mask_where_outside()')
    lower = _limit_from_qube(self, lower, masked=-np.inf, op='mask_where_outside()')
    upper = _limit_from_qube(self, upper, masked=+np.inf, op='mask_where_outside()')

    # A single flag applies to both endpoints
    if not isinstance(mask_endpoints, (tuple, list)):
        mask_endpoints = (mask_endpoints, mask_endpoints)

    # Use the comparison operators rather than calling the bound dunder methods directly.
    # For Python-scalar values a direct call such as (3).__lt__(2.5) returns
    # NotImplemented instead of falling back to the reflected comparison, and the "|"
    # below would then raise TypeError. Exactly one comparison per limit is evaluated.
    values = self._values

    if mask_endpoints[0]:               # lower end point is included in the mask
        below_lower = values <= lower
    else:                               # lower end point is excluded from the mask
        below_lower = values < lower

    if mask_endpoints[1]:               # upper end point is included in the mask
        above_upper = values >= upper
    else:                               # upper end point is excluded from the mask
        above_upper = values > upper

    mask = below_lower | above_upper

    return self.mask_where(mask, replace=replace, remask=remask)


def clip(self, lower, upper, *, remask=True, inclusive=True):
    """Return a copy with values clipped to fall within a pair of limits.

    Values below the lower limit are set to the lower limit and values above the upper
    limit are set to the upper limit.

    Parameters:
        lower (object, optional): The numerical lower limit or an object of the same shape
            and type as this, containing lower limits. None or masked values to ignore.
        upper (object, optional): The numerical upper limit or an object of the same shape
            and type as this, containing upper limits. None or masked values to ignore.
        remask (bool, optional): True to leave the new values masked; False to replace
            the values but leave them unmasked.
        inclusive (bool, optional): True to leave values that exactly match the upper
            limit unmasked; False to mask them.

    Returns:
        Qube: A copy with values clipped to the specified limits.

    Raises:
        ValueError: If this object has denominators or item rank > 0.
    """

    self._disallow_denom('clip()')
    self._require_scalar('clip()')

    # Fast path: both limits are plain scalars, so NumPy can clip in a single call
    if np.isscalar(lower) and np.isscalar(upper):
        clipped = np.clip(self._values, lower, upper)

        if remask:
            # Out-of-range values are added to the existing mask; derivatives are reused
            out_of_range = Qube.is_outside(self._values, lower, upper, inclusive)
            new_mask = Qube.or_(self._mask, out_of_range)
            derivs = self._derivs
        elif self._derivs:
            # Without remasking, derivatives out of range are now all zero
            out_of_range = Qube.is_outside(self._values, lower, upper, inclusive)
            derivs = {}
            for name, deriv in self._derivs.items():
                zeroed = deriv.copy()
                zeroed[out_of_range] = deriv.zero()
                derivs[name] = zeroed
            new_mask = self._mask
        else:
            new_mask = self._mask
            derivs = self._derivs

        clipped_obj = Qube.__new__(type(self))
        clipped_obj.__init__(clipped, new_mask, derivs=derivs, example=self)
        return clipped_obj

    # General path: apply each limit that was given, one after the other, letting
    # mask_where() substitute the limit for the offending values
    result = self

    if lower is not None:
        too_low = result._values < lower
        result = result.mask_where(too_low, replace=lower, remask=remask)

    if upper is not None:
        # `inclusive` decides whether a value exactly at the upper limit is selected
        if inclusive:
            too_high = result._values > upper
        else:
            too_high = result._values >= upper
        result = result.mask_where(too_high, replace=upper, remask=remask)

    return result


def _limit_from_qube(self, limit, masked, op):
    """Interpret the limit.

    Parameters:
        self (Qube): The object for which the limit is to be applied.
        limit (Qube, array-like, float, or int): Limit value. If it is array-like or a
            Qube, it must be broadcastable to self.shape. Also, if it is a Qube, the shape
            of its numerator must be either () or self.numer.
        masked (Qube, array-like, float or int): The value(s) to use where `limit` is
            masked if it is a Qube. If it is array-like or a Qube, it must be
            broadcastable to self.shape. Also, if it is a Qube, the shape of its numerator
            must be either () or self.numer.
        op (str): Operation name to appear in an error message.

    Returns:
        (numpy.ndarray, float, or int): The value of `limit` as an array or scalar.
            Anything that is neither a NumPy array nor a Qube is returned unchanged. A
            fully masked Qube yields `masked` itself.

    Raises:
        ValueError: If `limit` is a Qube with denominators, or with a numerator shape
            that differs from that of `self`.
    """

    if isinstance(limit, np.ndarray):
        if self._rank:      # limits apply to items overall, not to individual components
            # Append one unit axis per item axis while keeping the limit's own shape, so
            # that it broadcasts against the items. Reshaping to the unit axes alone
            # would fail for any limit array holding more than one element.
            limit = np.reshape(limit, limit.shape + self._rank * (1,))
        return limit

    if not isinstance(limit, Qube):
        return limit

    if not isinstance(limit._mask, np.ndarray) and limit._mask:     # if fully masked
        return masked

    if limit._drank:
        raise ValueError(self._opstr(op) + ' limit does not support denominators')

    # Trailing axes of the result: numerator axes followed by unit denominator axes
    tail = self._drank * (1,)

    if limit._numer:
        if limit._numer != self._numer:
            raise ValueError(self._opstr(op) + ' limit item does not match object: '
                             f'{limit._numer}, {self._numer}')
        tail = limit._numer + tail
    elif self._numer:
        tail = self._nrank * (1,) + tail

    vals = np.broadcast_to(limit._values, self._shape + tail)

    if not np.any(limit._mask):
        return vals

    # Partially masked limit: substitute `masked` wherever the limit is masked.
    # broadcast_to() returns a read-only view, so copy before writing into it.
    mask = np.reshape(limit.mask, limit._mask.shape + self._rank * (1,))
    mask = np.broadcast_to(mask, vals.shape)
    vals = vals.copy()
    vals[mask] = masked
    return vals

#########################################################################################
# Convenience methods for range masks and clipping
#########################################################################################

@staticmethod
def is_below(arg, high, inclusive=True):
    """Check if arg is inside a range with upper end at high.

    Parameters:
        arg (object): The value to check.
        high (object): The upper limit of the range.
        inclusive (bool, optional): True to include the upper limit in the range; False to
            exclude it.

    Returns:
        bool: True if arg is inside the range with upper end at high.
    """

    if inclusive:
        return arg <= high
    else:
        return arg < high


@staticmethod
def is_above(arg, high, inclusive=True):
    """Check if arg is outside a range with upper end at high.

    Parameters:
        arg (object): The value to check.
        high (object): The upper limit of the range.
        inclusive (bool, optional): True to include the upper limit in the range; False to
            exclude it.

    Returns:
        bool: True if arg is outside the range with upper end at high.
    """

    if inclusive:
        return arg > high
    else:
        return arg >= high


@staticmethod
def is_outside(arg, low, high, inclusive=True):
    """Check if arg is outside the range low to high.

    Parameters:
        arg (object): The value to check.
        low (object): The lower limit of the range.
        high (object): The upper limit of the range.
        inclusive (bool, optional): True to include the upper limit in the range; False to
            exclude it.

    Returns:
        bool: True if arg is outside the range low to high.
    """

    if inclusive:
        return (arg < low) | (arg > high)
    else:
        return (arg < low) | (arg >= high)


@staticmethod
def is_inside(arg, low, high, inclusive=True):
    """Check if arg is inside the range low to high.

    Parameters:
        arg (object): The value to check.
        low (object): The lower limit of the range.
        high (object): The upper limit of the range.
        inclusive (bool, optional): True to include the upper limit in the range; False to
            exclude it.

    Returns:
        bool: True if arg is inside the range low to high.
    """

    if inclusive:
        return (arg >= low) & (arg <= high)
    else:
        return (arg >= low) & (arg < high)

#########################################################################################