Source code for polymath.extensions.shrinker

################################################################################
# polymath/extensions/shrinker.py: shrink and unshrink operations
################################################################################

import numpy as np
from polymath.qube import Qube
from polymath.scalar import Scalar


def shrink(self, antimask):
    """A 1-D version of this object, containing only the samples in the antimask provided.

    The antimask array value of True indicates that an element should be included; False
    means that it should be discarded. A scalar value of True or False applies to the
    entire object.

    The purpose is to speed up calculations by first eliminating all the objects that are
    masked. Any calculation involving un-shrunken objects should produce the same result
    if the same objects are all shrunken by a common antimask first, the calculation is
    performed, and then the result is un-shrunken afterward.

    Newly constructed shrunken objects are converted to read-only. When this object is
    returned as is (a True antimask, or a shapeless object), it is not converted.

    Unless Qube._DISABLE_CACHE is set, the object that was shrunk is saved in the cache of
    the returned object under the key 'unshrunk'.

    Parameters:
        antimask (np.ndarray or bool): True where a sample is to be kept. An array is
            broadcast against the rightmost axes of this object; if it has more axes than
            this object, this object is broadcast to the antimask's shape instead.

    Returns:
        Qube: This object if nothing needs to be removed; a single masked value if nothing
        survives; otherwise, a new object of the same class in which the axes spanned by
        the antimask have been replaced by one axis holding the selected samples.
    """

    # For testing only...
    if Qube._DISABLE_SHRINKING:
        if self._is_scalar or Qube.is_one_true(antimask):
            return self
        return self.mask_where(np.logical_not(antimask))

    # A True antimask leaves an object unchanged
    if Qube.is_one_true(antimask):
        return self

    # Every cache write below honors the same switch
    use_cache = not Qube._DISABLE_CACHE

    # If the antimask is a single False value, or if this object is already
    # entirely masked, return a single masked value
    if (Qube.is_one_true(self._mask) or Qube.is_one_false(antimask) or
            not np.any(antimask & self.antimask)):
        obj = self.masked_single().as_readonly()
        if use_cache:
            obj._cache['unshrunk'] = self
        return obj

    # If this is a shapeless object, return it as is
    if self._is_scalar:
        if use_cache:
            self._cache['unshrunk'] = self
        return self

    # Beyond this point, the size of the last axis in the returned object will have the
    # same number of elements as the number of True elements in the antimask.

    # Ensure that this object and the antimask have compatible dimensions. If the antimask
    # has extra dimensions, broadcast self to make it work
    extras = self._ndims - antimask.ndim
    if extras < 0:
        self = self.broadcast_to(antimask.shape, recursive=False)
        extras = 0

    # If self has extra dimensions, these will be retained and only the
    # rightmost axes will be flattened.
    before = self._shape[:extras]      # shape of self retained
    after = self._shape[extras:]       # shape of self to be masked

    # Make the rightmost axes of self and the mask compatible
    new_after = tuple(max(own, other) for own, other in zip(after, antimask.shape))
    new_shape = before + new_after
    if self._shape != new_shape:
        self = self.broadcast_to(new_shape, recursive=False)
    if antimask.shape != new_after:
        antimask = np.broadcast_to(antimask, new_after)

    # Leading axes are kept whole; the antimask selects along the remaining axes
    selector = extras * (slice(None),) + (antimask, Ellipsis)

    # Construct the new mask
    if Qube.is_one_false(self._mask):
        # The antimask is known to contain at least one True (checked above), so the
        # selection is non-empty and entirely unmasked; no array is needed.
        mask = False
    else:
        mask = self._mask[selector]

        # Nothing selected is unmasked: collapse to a single masked value
        if np.all(mask):
            obj = self.masked_single().as_readonly()
            if use_cache:
                obj._cache['unshrunk'] = self
            return obj

        # Nothing selected is masked: a single False is sufficient
        if not np.any(mask):
            mask = False

    # Construct the new object
    obj = Qube.__new__(type(self))
    obj.__init__(self._values[selector], mask, example=self)
    obj = obj.as_readonly()

    # Shrink the derivatives with the same (possibly broadcast) antimask
    # NOTE(review): self may have been replaced above by broadcast_to(..., recursive=False);
    # confirm that the derivatives are still present in that case.
    for key, deriv in self._derivs.items():
        obj.insert_deriv(key, deriv.shrink(antimask))

    # Cache values to speed things up later
    if use_cache:
        obj._cache['unshrunk'] = self
    return obj


def unshrink(self, antimask, shape=()):
    """Expand a shrunken object back to the shape implied by an antimask.

    This reverses shrink(). If this object was produced by shrink(), the antimask must be
    the one that was used to shrink it. Otherwise, the size of this object's last axis
    must equal the number of True values in the antimask.

    Parameters:
        antimask (array-like): The antimask to apply.
        shape (tuple, optional): The shape of the returned object in the case where the
            result is entirely masked (for example, when the antimask is a literal False).
            Normally, the rightmost axes of the returned object match those of the
            antimask.

    Returns:
        Qube: The un-shrunken object. A newly constructed result is converted to
        read-only. This object is returned as is when shrinking is disabled, when the
        antimask is a single True, or when this object is shapeless and not entirely
        masked.
    """

    # For testing only...
    if Qube._DISABLE_SHRINKING:
        return self

    # Look up the pre-shrink object saved by shrink(), if any. The cache entry is removed
    # as soon as it is found, even if it is not used below.
    cached = None
    if not Qube._DISABLE_CACHE:
        cached = self._cache.get('unshrunk', None)
        if cached is not None:
            del self._cache['unshrunk']
            if Qube._IGNORE_UNSHRUNK_AS_CACHED:
                cached = None

    # A True antimask means nothing was removed
    if Qube.is_one_true(antimask):
        return self

    # Nothing selected, or nothing unmasked: the result is a masked value of given shape
    entirely_masked = not np.any(antimask) or np.all(self._mask)
    if entirely_masked:
        return self.masked_single().broadcast_to(shape)

    # A shapeless object has nothing to expand
    if self._is_scalar:
        return self

    # Re-use the pre-shrink object, masking whatever the antimask excludes
    if cached is not None:
        return cached.mask_where(np.logical_not(antimask))

    # Otherwise, rebuild from scratch. The last axis of this object is replaced by the
    # axes of the antimask; all leading axes are retained.
    full_shape = self._shape[:-1] + antimask.shape
    selector = (self._ndims - 1) * (slice(None),) + (antimask, Ellipsis)

    if not self._is_array:
        # A single value can be handled by broadcasting
        new_values = Scalar(self._values).broadcast_to(full_shape)._values
    else:
        # Start from the default value everywhere...
        fill = self._default
        if isinstance(fill, Qube):
            fill = fill._values

        new_values = np.empty(full_shape + self._item, self._values.dtype)
        new_values[...] = fill

        # ...then drop the shrunken values into the selected locations
        new_values[selector] = self._values

    # Everything is masked except where the antimask selects; there, the shrunken mask
    # applies
    new_mask = np.ones(full_shape, dtype=np.bool_)
    new_mask[selector] = self._mask

    # Assemble the result as an object of the same class
    result = Qube.__new__(type(self))
    result.__init__(new_values, new_mask, example=self)
    result = result.as_readonly()

    # Derivatives are un-shrunken the same way
    for key, deriv in self._derivs.items():
        result.insert_deriv(key, deriv.unshrink(antimask, shape))

    return result

################################################################################