# Source code for polymath.extensions.shaper

##########################################################################################
# polymath/extensions/shaper.py: re-shaping operations
##########################################################################################

import numpy as np
from polymath.qube import Qube


def reshape(self, shape, *, recursive=True):
    """Return a shallow copy of the object with a new leading shape.

    Parameters:
        shape (tuple or int): A tuple defining the new leading shape. A value of -1 can
            appear at one location in the new shape, and the size of that shape will be
            determined based on this object's size.
        recursive (bool, optional): True to apply the same shape to the derivatives.
            Otherwise, derivatives are deleted from the returned object.

    Returns:
        Qube: A shallow copy with the new shape. If the shape is unchanged (and no
            derivatives need to be dropped because of recursive=False), this object is
            returned without modification. The read-only status is preserved.

    Raises:
        ValueError: If the new shape is incompatible with the current shape.
    """

    # Normalize the shape to a tuple so it can be compared and concatenated
    if np.isscalar(shape):
        shape = (shape,)
    elif not isinstance(shape, tuple):
        shape = tuple(shape)

    # Honor the documented shortcut: an unchanged shape returns this object as-is,
    # unless recursive=False obliges us to return a copy without derivatives.
    if shape == self._shape and (recursive or not self._derivs):
        return self

    # The item shape trails the leading shape in the values array
    new_values = np.reshape(self._values, shape + self._item)
    if isinstance(self._mask, np.ndarray):
        new_mask = self._mask.reshape(shape)
    else:
        new_mask = self._mask

    obj = Qube.__new__(type(self))
    obj.__init__(new_values, new_mask, example=self)
    obj._readonly = self._readonly

    if recursive:
        for key, deriv in self._derivs.items():
            obj.insert_deriv(key, deriv.reshape(shape, recursive=False))

    return obj


def flatten(self, *, recursive=True):
    """Return a shallow copy of the object collapsed into a single leading dimension.

    Objects that already have zero or one leading dimension are returned unchanged.

    Parameters:
        recursive (bool, optional): True to flatten the derivatives in the same way.
            Otherwise, derivatives are deleted from the returned object.

    Returns:
        Qube: A shallow copy flattened to one dimension.
    """

    # Only objects with two or more leading axes need reshaping
    if self._ndims > 1:
        total_size = np.prod(self._shape)
        return self.reshape((total_size,), recursive=recursive)

    return self


def swap_axes(self, axis1, axis2, *, recursive=True):
    """Return a shallow copy of the object with two of its leading axes interchanged.

    Parameters:
        axis1 (int): The first axis to swap. Negative values count backward from the
            last leading axis, i.e., the axis just before the numerator items begin.
        axis2 (int): The second axis to swap.
        recursive (bool, optional): True to apply the identical swap to the
            derivatives. Otherwise, derivatives are deleted from the returned object.

    Returns:
        Qube: A shallow copy with the specified axes swapped; this object itself if
            both axes refer to the same position.

    Raises:
        ValueError: If either axis is out of range.
    """

    ndims = self._ndims
    self._require_axis_in_range(axis1, ndims, 'swap_axes()', name='axis1')
    self._require_axis_in_range(axis2, ndims, 'swap_axes()', name='axis2')

    first, second = axis1 % ndims, axis2 % ndims
    if first == second:
        return self         # a swap with itself is a no-op

    old_mask = self._mask
    if isinstance(old_mask, np.ndarray):
        swapped_mask = old_mask.swapaxes(first, second)
    else:
        swapped_mask = old_mask     # scalar masks have no axes to swap

    result = Qube.__new__(type(self))
    result.__init__(np.swapaxes(self._values, first, second), swapped_mask,
                    example=self)
    result._readonly = self._readonly

    if not recursive:
        return result

    for key, deriv in self._derivs.items():
        result.insert_deriv(key, deriv.swap_axes(first, second, recursive=False))
    return result


def roll_axis(self, axis, start=0, *, recursive=True, rank=None):
    """Return a shallow copy of the object with one axis rolled to a new position.

    Parameters:
        axis (int): The axis to roll; negative values count from the end.
        start (int, optional): The axis is rolled so that it lands in front of this
            axis. A value equal to the rank moves the axis to the end.
        recursive (bool, optional): True to apply the identical roll to the
            derivatives. Otherwise, derivatives are deleted from the returned object.
        rank (int, optional): Rank to assume for the object, which may exceed
            len(self.shape) because of broadcasting; missing leading axes are added.

    Returns:
        Qube: A shallow copy with the axis rolled to the new position.

    Raises:
        ValueError: If the rank is too small for the object shape.
        ValueError: If the axis or start is out of range.
    """

    if rank is None:
        rank = self._ndims
    if rank < self._ndims:
        opstr = self._opstr('roll_axis()')
        raise ValueError(f'{opstr} rank {rank} is too small for shape {self._shape}')

    # Resolve the axis being rolled (may be negative)
    self._require_axis_in_range(axis, rank, 'roll_axis()')
    src = axis % rank

    # Resolve the destination (may be negative); start == rank is explicitly allowed
    if start != rank:
        self._require_axis_in_range(start, rank, 'roll_axis()', 'start')
    dst = start if start >= 0 else start + rank

    # Pad with leading unit axes when the requested rank exceeds the actual one
    missing = rank - self._ndims
    if missing > 0:
        base = self.reshape(missing * (1,) + self._shape, recursive=recursive)
    else:
        base = self

    base_mask = base._mask
    if isinstance(base_mask, np.ndarray):
        rolled_mask = np.rollaxis(base_mask, src, dst)
    else:
        rolled_mask = base_mask

    result = Qube.__new__(type(base))
    result.__init__(np.rollaxis(base._values, src, dst), rolled_mask, example=base)
    result._readonly = base._readonly

    if recursive:
        for key, deriv in base._derivs.items():
            result.insert_deriv(key, deriv.roll_axis(src, dst, recursive=False,
                                                     rank=rank))

    return result


def move_axis(self, source, destination, *, recursive=True, rank=None):
    """Return a shallow copy of the object with one or more axes moved.

    Parameters:
        source (int or tuple): The axis, or tuple of axes, to move.
        destination (int or tuple): The new position(s) of the moved axis or axes.
        recursive (bool, optional): True to apply the identical move to the
            derivatives. Otherwise, derivatives are deleted from the returned object.
        rank (int, optional): Rank to assume for the object, which may exceed
            len(self.shape) because of broadcasting; missing leading axes are added.

    Returns:
        Qube: A shallow copy with the specified axes moved to their new positions.

    Raises:
        ValueError: If the rank is too small for the object shape.
        ValueError: If any axis is out of range.
    """

    if rank is None:
        rank = self._ndims
    if rank < self._ndims:
        opstr = self._opstr('move_axis()')
        raise ValueError(f'{opstr} rank {rank} is too small for shape {self._shape}')

    # Accept a single axis or a sequence of axes for either argument
    sources = (source,) if np.isscalar(source) else source
    targets = (destination,) if np.isscalar(destination) else destination

    for ax in sources:
        self._require_axis_in_range(ax, rank, 'move_axis()', 'source')
    for ax in targets:
        self._require_axis_in_range(ax, rank, 'move_axis()', 'destination')

    # Convert any negative indices into their non-negative equivalents
    sources = tuple(ax % rank for ax in sources)
    targets = tuple(ax % rank for ax in targets)

    # Pad with leading unit axes when the requested rank exceeds the actual one
    missing = rank - self._ndims
    if missing > 0:
        base = self.reshape(missing * (1,) + self._shape, recursive=recursive)
    else:
        base = self

    base_mask = base._mask
    if isinstance(base_mask, np.ndarray):
        moved_mask = np.moveaxis(base_mask, sources, targets)
    else:
        moved_mask = base_mask

    result = Qube.__new__(type(base))
    result.__init__(np.moveaxis(base._values, sources, targets), moved_mask,
                    example=base)
    result._readonly = base._readonly

    if recursive:
        for key, deriv in base._derivs.items():
            moved = deriv.move_axis(sources, targets, recursive=False, rank=rank)
            result.insert_deriv(key, moved)

    return result


@staticmethod
def stack(*args, recursive=True):
    """Stack objects into one with a new leading axis.

    Parameters:
        *args: Any number of Scalars or arguments that can be casted to Scalars. They need
            not have the same shape, but it must be possible to cast them to the same
            shape. A value of None is converted to a zero-valued Scalar that matches the
            denominator shape of the other arguments.
        recursive (bool, optional): True to include all the derivatives. The returned
            object will have derivatives representing the union of all the derivatives
            found amongst the scalars.

    Returns:
        Qube: A stacked object with a new leading axis.

    Raises:
        TypeError: If an unexpected keyword argument is provided.
        ValueError: If the arguments have incompatible denominators, or if no argument
            other than None is provided.
    """

    args = list(args)

    # Get the type and unit if any.
    # Floats take precedence over ints, and ints over booleans. Within a type, an
    # argument that was already a Qube (not converted here) is preferred as the
    # template, so that a specific subclass is used when one is available; only use
    # class Qube if no suitable subclass was found.
    floats_found = False
    ints_found = False

    float_arg = None
    int_arg = None
    bool_arg = None

    unit = None
    denom = None
    subclass_indx = None

    for i, arg in enumerate(args):
        if arg is None:
            continue

        qubed = False
        if not isinstance(arg, Qube):
            arg = Qube(arg)
            args[i] = arg
            qubed = True

        if denom is None:
            denom = arg._denom
        elif denom != arg._denom:
            raise ValueError('incompatible denominator shapes for stack(): '
                             f'{denom}, {arg._denom}')

        if arg.is_float():
            floats_found = True
            if float_arg is None or not qubed:
                float_arg = arg
                subclass_indx = i
        elif arg.is_int() and float_arg is None:
            ints_found = True
            if int_arg is None or not qubed:
                int_arg = arg
                subclass_indx = i
        elif arg.is_bool() and int_arg is None and float_arg is None:
            if bool_arg is None or not qubed:
                bool_arg = arg
                subclass_indx = i

        if arg._unit is not None:
            if unit is None:
                unit = arg._unit
            else:
                arg.confirm_unit(unit)

    # Without at least one usable argument there is no template for the class,
    # shape, or denominator of the result; fail clearly instead of on len(None).
    if subclass_indx is None:
        raise ValueError('stack() requires at least one numeric or boolean argument '
                         'that is not None')

    drank = len(denom)

    # Convert to subclass and type
    template = args[subclass_indx]
    for i, arg in enumerate(args):
        if arg is None:                 # Used as placeholder for derivs
            continue

        args[i] = template.as_this_type(arg, recursive=recursive, coerce=False)

    # Broadcast all inputs into a common shape
    args = Qube.broadcast(*args, recursive=True)

    # Determine what type of mask is needed
    mask_true_found = False
    mask_false_found = False
    mask_array_found = False
    for arg in args:
        if arg is None:
            continue
        elif Qube.is_one_true(arg._mask):
            mask_true_found = True
        elif Qube.is_one_false(arg._mask):
            mask_false_found = True
        else:
            mask_array_found = True

    # Construct the mask; a single boolean suffices when all masks agree.
    # Entries for None arguments stay False (unmasked zeros).
    if mask_array_found or (mask_false_found and mask_true_found):
        mask = np.zeros((len(args),) + args[subclass_indx].shape, dtype=np.bool_)
        for i, arg in enumerate(args):
            if arg is not None:
                mask[i] = arg._mask
    else:
        mask = mask_true_found

    # Construct the array
    if floats_found:
        dtype = np.float64
    elif ints_found:
        dtype = np.int_
    else:
        dtype = np.bool_

    values = np.empty((len(args),) + np.shape(args[subclass_indx]._values), dtype=dtype)
    for i, arg in enumerate(args):
        values[i] = 0 if arg is None else arg._values

    # Construct the result
    result = Qube.__new__(type(args[subclass_indx]))
    result.__init__(values, mask, unit=unit, drank=drank)

    # Fill in derivatives if necessary
    if recursive:
        # Union of derivative keys in first-seen order; unlike a set, this keeps the
        # result's derivative ordering deterministic from run to run.
        keys = {}
        for arg in args:
            if arg is not None:
                keys.update(dict.fromkeys(arg._derivs))

        derivs = {}
        for key in keys:
            deriv_list = [None if arg is None else arg._derivs.get(key, None)
                          for arg in args]
            derivs[key] = Qube.stack(*deriv_list, recursive=False)

        result.insert_derivs(derivs)

    return result

##########################################################################################