Source code for polymath.extensions.iterator

################################################################################
# polymath/extensions/iterator.py: iterator over Qube objects
################################################################################

import itertools
import numpy as np


class QubeIterator(object):
    """Provide iteration across the first axis of a Qube object.

    This iterator allows iteration over elements along the first axis of a Qube object,
    similar to how NumPy arrays can be iterated.

    Attributes:
        obj (list or Qube): The object to iterate over.
        stop (int): The number of elements to iterate through.
        index (int): The current iteration position.
    """

    def __init__(self, obj):
        """Initialize the iterator with a Qube object.

        Parameters:
            obj (Qube): The Qube object to iterate over.
        """

        if not obj._shape:
            self.obj = [obj]
            self.stop = 1
        else:
            self.obj = obj
            self.stop = obj._shape[0]

        self.index = -1

    def __iter__(self):
        """Return the iterator object itself.

        Returns:
            QubeIterator: This iterator instance.
        """

        self.index = -1
        return self

    def __next__(self):
        """Return the next item in the iteration.

        Returns:
            Qube: The next element in the iteration.

        Raises:
            StopIteration: When iteration is complete.
        """

        self.index += 1
        if self.index >= self.stop:
            raise StopIteration

        return self.obj[self.index]


class QubeNDIterator(object):
    """Provide iteration across all axes of a Qube object.

    This iterator allows iteration over all elements in a multi-dimensional Qube object,
    returning both the index tuple and the value at that index.

    Attributes:
        obj (ndarray): The object to iterate over.
        shape (tuple): The shape of the object.
        iterator (iterator): The underlying iterator.
    """

    def __init__(self, obj):
        """Initialize the iterator with a Qube object.

        Parameters:
            obj (Qube): The Qube object to iterate over.
        """

        if not obj._shape:
            self.obj = np.array([obj], dtype='object')
            self.shape = (1,)
        else:
            self.obj = obj
            self.shape = obj._shape

        self.iterator = None

    def __iter__(self):
        """Return the iterator object itself.

        Returns:
            QubeNDIterator: This iterator instance.
        """

        self.iterator = itertools.product(*[range(s) for s in self.shape])
        return self

    def __next__(self):
        """Return the next item in the iteration.

        Returns:
            tuple: A tuple containing (index_tuple, item_at_index).

        Raises:
            StopIteration: When iteration is complete.
        """

        indx = self.iterator.__next__()
        return (indx, self.obj[indx])


def __iter__(self):
    """Return an iterator over the first axis of this object.

    Returns:
        QubeIterator: An iterator that iterates over the first axis of the object.
    """

    return QubeIterator(obj=self)


def ndenumerate(self):
    """Iterate across all axes of this object.

    This method provides an iterator that returns tuples containing the index
    and the corresponding item at that index.

    Returns:
        QubeNDIterator: An iterator yielding (index_tuple, item_at_index) pairs.
    """

    return QubeNDIterator(obj=self)

################################################################################