"""
Pipeline modules for stacking and subsampling of images.
"""
import time
import warnings
from typing import List, Optional, Tuple
import numpy as np
from typeguard import typechecked
from pynpoint.core.processing import ProcessingModule
from pynpoint.util.module import progress, memory_frames, stack_angles, angle_average
from pynpoint.util.image import rotate_images
class StackAndSubsetModule(ProcessingModule):
    """
    Pipeline module for stacking subsets of images and/or selecting a random sample of images.
    """

    @typechecked
    def __init__(self,
                 name_in: str,
                 image_in_tag: str,
                 image_out_tag: str,
                 random: Optional[int] = None,
                 stacking: Optional[int] = None,
                 combine: str = 'mean',
                 max_rotation: Optional[float] = None) -> None:
        """
        Parameters
        ----------
        name_in : str
            Unique name of the module instance.
        image_in_tag : str
            Tag of the database entry that is read as input.
        image_out_tag : str
            Tag of the database entry that is written as output. Should be different from
            *image_in_tag*.
        random : int, None
            Number of random images. All images are used if set to None.
        stacking : int, None
            Number of stacked images per subset. No stacking is applied if set to None.
        combine : str
            Method for combining images ('mean' or 'median'). Only used when ``stacking`` is
            not None. The angles are always combined with ``angle_average``.
        max_rotation : float, None
            Maximum allowed field rotation (deg) throughout each subset of stacked images when
            `stacking` is not None. No restriction on the field rotation is applied if set to
            None. Requires the ``PARANG`` attribute of the input data.

        Returns
        -------
        NoneType
            None
        """

        super().__init__(name_in)

        self.m_image_in_port = self.add_input_port(image_in_tag)
        self.m_image_out_port = self.add_output_port(image_out_tag)

        self.m_random = random
        self.m_stacking = stacking
        self.m_combine = combine
        self.m_max_rotation = max_rotation

        if self.m_stacking is None and self.m_random is None:
            warnings.warn('Both \'stacking\' and \'random\' are set to None.')

    @typechecked
    def run(self) -> None:
        """
        Run method of the module. Stacks subsets of images and/or selects a random subset. Also
        the parallactic angles are mean-combined if images are stacked.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If more random images are requested than are available (before or after
            stacking), if ``combine`` is not 'mean' or 'median' while stacking, or if
            ``max_rotation`` is set while the input has no ``PARANG`` attribute.
        """

        @typechecked
        def _stack_subsets(nimages: int,
                           im_shape: Tuple[int, ...],
                           parang: Optional[np.ndarray]) -> Tuple[Tuple[int, ...],
                                                                  Optional[np.ndarray],
                                                                  Optional[np.ndarray]]:
            # Stack consecutive subsets of images. Returns the (possibly updated) image shape,
            # the stacked images (None when no stacking is requested), and the matching angles
            # (None when the input has no PARANG attribute).

            if self.m_stacking is None:
                parang_new = None if parang is None else np.copy(parang)
                return im_shape, None, parang_new

            if self.m_max_rotation is not None:
                frames = stack_angles(self.m_stacking, parang, self.m_max_rotation)
            else:
                frames = memory_frames(self.m_stacking, nimages)

            # frames holds the subset boundaries, so there is one subset less than entries
            nimages_new = np.size(frames) - 1

            parang_new = None if parang is None else np.zeros(nimages_new)
            im_new = np.zeros((nimages_new, im_shape[1], im_shape[2]))

            start_time = time.time()

            for i in range(nimages_new):
                progress(i, nimages_new, 'Stacking subsets of images...', start_time)

                if parang is not None:
                    parang_new[i] = angle_average(parang[frames[i]:frames[i+1]])

                im_subset = self.m_image_in_port[frames[i]:frames[i+1], ]

                if self.m_combine == 'mean':
                    im_new[i, ] = np.mean(im_subset, axis=0)
                elif self.m_combine == 'median':
                    im_new[i, ] = np.median(im_subset, axis=0)

            return im_new.shape, im_new, parang_new

        @typechecked
        def _random_subset(im_shape: Tuple[int, ...],
                           im_new: Optional[np.ndarray],
                           parang_new: Optional[np.ndarray]) -> Tuple[int,
                                                                      Optional[np.ndarray],
                                                                      Optional[np.ndarray]]:
            # Select a random subset (without replacement) of the (stacked) images. Returns the
            # number of output images, the selected images, and the matching angles.

            if self.m_random is not None:
                choice = np.random.choice(im_shape[0], self.m_random, replace=False)
                # Sorting keeps the selected images in their original order (presumably also
                # needed because the input port only accepts increasing index lists)
                choice = list(np.sort(choice))

                if parang_new is not None:
                    parang_new = parang_new[choice]

                if self.m_stacking is None:
                    im_new = self.m_image_in_port[choice, ]
                else:
                    im_new = im_new[choice, ]

            if im_new is None:
                # Neither stacking nor random selection was requested
                nimages = 0
            elif im_new.ndim == 2:
                nimages = 1
            else:
                nimages = im_new.shape[0]

            return nimages, im_new, parang_new

        non_static = self.m_image_in_port.get_all_non_static_attributes()

        im_shape = self.m_image_in_port.get_shape()
        nimages = im_shape[0]

        if self.m_random is not None:
            if self.m_stacking is None and im_shape[0] < self.m_random:
                raise ValueError('The number of images of the destination subset is larger than '
                                 'the number of images in the source.')

            if self.m_stacking is not None and \
                    int(float(im_shape[0])/float(self.m_stacking)) < self.m_random:
                raise ValueError('The number of images of the destination subset is larger than '
                                 'the number of images in the stacked source.')

        if self.m_stacking is not None and self.m_combine not in ('mean', 'median'):
            # Without this check the stacked images would silently remain all zeros
            raise ValueError(f'The argument of \'combine\' should be \'mean\' or \'median\', '
                             f'not \'{self.m_combine}\'.')

        if 'PARANG' in non_static:
            parang = self.m_image_in_port.get_attribute('PARANG')
        else:
            parang = None

        if self.m_stacking is not None and self.m_max_rotation is not None and parang is None:
            raise ValueError('The PARANG attribute is required when \'max_rotation\' is set.')

        im_shape, im_new, parang_new = _stack_subsets(nimages, im_shape, parang)
        nimages, im_new, parang_new = _random_subset(im_shape, im_new, parang_new)

        if self.m_random or self.m_stacking:
            self.m_image_out_port.set_all(im_new, keep_attributes=True)
            self.m_image_out_port.copy_attributes(self.m_image_in_port)
            self.m_image_out_port.add_attribute('INDEX', np.arange(0, nimages, 1), static=False)

            if parang_new is not None:
                self.m_image_out_port.add_attribute('PARANG', parang_new, static=False)

            # The copied frame counts no longer describe the stacked/selected images
            if 'NFRAMES' in non_static:
                self.m_image_out_port.del_attribute('NFRAMES')

        history = f'stacking = {self.m_stacking}, random = {self.m_random}'
        self.m_image_out_port.add_history('StackAndSubsetModule', history)
        self.m_image_out_port.close_port()
class StackCubesModule(ProcessingModule):
    """
    Pipeline module for calculating the mean or median of each original data cube associated with a
    database tag.
    """

    @typechecked
    def __init__(self,
                 name_in: str,
                 image_in_tag: str,
                 image_out_tag: str,
                 combine: str = 'mean') -> None:
        """
        Parameters
        ----------
        name_in : str
            Unique name of the module instance.
        image_in_tag : str
            Tag of the database entry that is read as input.
        image_out_tag : str
            Tag of the database entry with the mean or median collapsed images that are written
            as output. Should be different from *image_in_tag*.
        combine : str
            Method to combine the images ('mean' or 'median').

        Returns
        -------
        NoneType
            None
        """

        super().__init__(name_in)

        self.m_image_in_port = self.add_input_port(image_in_tag)
        self.m_image_out_port = self.add_output_port(image_out_tag)

        self.m_combine = combine

    @typechecked
    def run(self) -> None:
        """
        Run method of the module. Uses the ``NFRAMES`` attribute to select the images of each cube,
        calculates the mean or median of each cube, and saves the data and attributes. The
        ``PARANG`` values of each cube are combined with ``angle_average``.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If the input and output tags are identical or if ``combine`` is not 'mean' or
            'median'.
        """

        if self.m_image_in_port.tag == self.m_image_out_port.tag:
            raise ValueError('Input and output port should have a different tag.')

        # Resolve the combine method once; an unknown value previously left the stacked image
        # unbound (first cube) or silently reused the previous cube's image
        if self.m_combine == 'mean':
            combine_images = np.mean
        elif self.m_combine == 'median':
            combine_images = np.median
        else:
            raise ValueError(f'The argument of \'combine\' should be \'mean\' or \'median\', '
                             f'not \'{self.m_combine}\'.')

        non_static = self.m_image_in_port.get_all_non_static_attributes()
        nframes = self.m_image_in_port.get_attribute('NFRAMES')

        if 'PARANG' in non_static:
            parang = self.m_image_in_port.get_attribute('PARANG')
        else:
            parang = None

        # Index of the first image of the current cube
        current = 0
        parang_new = []
        start_time = time.time()

        for i, frames in enumerate(nframes):
            progress(i, len(nframes), 'Stacking images per FITS cube...', start_time)

            im_stack = combine_images(self.m_image_in_port[current:current+frames, ], axis=0)
            self.m_image_out_port.append(im_stack, data_dim=3)

            if parang is not None:
                # Same angle helper as StackAndSubsetModule; a plain np.mean is wrong for angles
                # that wrap around +/-180 deg (angle_average assumed to be a circular mean)
                parang_new.append(angle_average(parang[current:current+frames]))

            current += frames

        nimages = np.size(nframes)

        self.m_image_out_port.copy_attributes(self.m_image_in_port)

        if 'INDEX' in non_static:
            index = np.arange(0, nimages, 1, dtype=int)
            self.m_image_out_port.add_attribute('INDEX', index, static=False)

        if 'NFRAMES' in non_static:
            # Each output image now represents exactly one collapsed cube
            nframes = np.ones(nimages, dtype=int)
            self.m_image_out_port.add_attribute('NFRAMES', nframes, static=False)

        if 'PARANG' in non_static:
            self.m_image_out_port.add_attribute('PARANG', parang_new, static=False)

        self.m_image_out_port.close_port()
class DerotateAndStackModule(ProcessingModule):
    """
    Pipeline module for derotating and/or stacking (i.e., taking the median or average) of the
    images, either along the time or the wavelengths dimension.
    """

    @typechecked
    def __init__(self,
                 name_in: str,
                 image_in_tag: str,
                 image_out_tag: str,
                 derotate: bool = True,
                 stack: Optional[str] = None,
                 extra_rot: float = 0.,
                 dimension: str = 'time') -> None:
        """
        Parameters
        ----------
        name_in : str
            Unique name of the module instance.
        image_in_tag : str
            Tag of the database entry that is read as input.
        image_out_tag : str
            Tag of the database entry that is written as output. The shape of the output data is
            equal to the data from ``image_in_tag``. If the argument of ``stack`` is not None,
            then the size of the collapsed dimension is equal to 1.
        derotate : bool
            Derotate the images with the ``PARANG`` attribute.
        stack : str
            Type of stacking applied after optional derotation ('mean', 'median', or None for no
            stacking).
        extra_rot : float
            Additional rotation angle of the images in clockwise direction (deg).
        dimension : str
            Dimension along which the images are stacked. Can either be 'time' or 'wavelength'. If
            the ``image_in_tag`` has three dimensions then ``dimension`` is always fixed to 'time'.

        Returns
        -------
        NoneType
            None
        """

        super().__init__(name_in)

        self.m_image_in_port = self.add_input_port(image_in_tag)
        self.m_image_out_port = self.add_output_port(image_out_tag)

        self.m_derotate = derotate
        self.m_stack = stack
        self.m_extra_rot = extra_rot
        self.m_dimension = dimension

    @typechecked
    def run(self) -> None:
        """
        Run method of the module. Uses the ``PARANG`` attributes to derotate the images (if
        ``derotate`` is set to ``True``) and applies an optional mean or median stacking
        along the time or wavelengths dimension afterwards. Two-dimensional input is treated
        as a stack that contains a single image.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If ``stack`` is not 'mean', 'median', or None, if ``dimension`` is not 'time' or
            'wavelength' for 4D data, or if the input data is not 2D, 3D, or 4D.
        """

        if self.m_stack not in (None, 'mean', 'median'):
            raise ValueError(f'The argument of \'stack\' should be \'mean\', \'median\', or '
                             f'None, not \'{self.m_stack}\'.')

        memory = self._m_config_port.get_attribute('MEMORY')

        parang = self.m_image_in_port.get_attribute('PARANG') if self.m_derotate else None

        ndim = self.m_image_in_port.get_ndim()
        im_shape = self.m_image_in_port.get_shape()
        npix = im_shape[-2]
        nwave = None

        # frames contains the boundaries of the processing steps; one step per interval
        if ndim == 2:
            nimages = 1
            frames = np.array([0, 1])
        elif ndim == 3:
            nimages = im_shape[-3]
            if self.m_stack == 'median':
                # The median needs all images at once, so process them in a single step
                frames = np.array([0, nimages])
            else:
                frames = memory_frames(memory, nimages)
        elif ndim == 4:
            nimages = im_shape[-3]
            nwave = im_shape[-4]
            if self.m_dimension == 'time':
                # One step per wavelength, containing all time frames
                frames = np.arange(nwave+1)
            elif self.m_dimension == 'wavelength':
                # One step per time frame, containing all wavelengths
                frames = np.arange(nimages+1)
            else:
                raise ValueError('The dimension should be set to \'time\' or \'wavelength\'.')
        else:
            raise ValueError(f'The input data should have 2, 3, or 4 dimensions, not {ndim}.')

        im_none = None     # derotated 4D cube (4D input, no stacking)
        im_reduced = None  # collapsed image per step (4D input, mean or median)
        im_sum = None      # running sum for the mean (2D/3D input)
        im_median = None   # median image (2D/3D input)

        if ndim == 4:
            if self.m_stack is None:
                im_none = np.zeros((nwave, nimages, npix, npix))
            else:
                nslices = nwave if self.m_dimension == 'time' else nimages
                im_reduced = np.zeros((nslices, npix, npix))
        elif self.m_stack == 'mean':
            im_sum = np.zeros((npix, npix))

        nsteps = np.size(frames) - 1
        start_time = time.time()

        for i in range(nsteps):
            progress(i, nsteps, 'Derotating and/or stacking images...', start_time)

            # Read the images of this step as a 3D array: (batch_size, height, width)
            if ndim == 2:
                images = self.m_image_in_port[:][np.newaxis, ...]
            elif ndim == 3:
                images = self.m_image_in_port[frames[i]:frames[i+1], ]
            elif self.m_dimension == 'time':
                images = self.m_image_in_port[i, :, ]
            else:
                images = self.m_image_in_port[:, i, ]

            if self.m_derotate:
                if ndim == 4 and self.m_dimension == 'time':
                    # Every time frame of this wavelength has its own angle
                    angles = -1.*parang + self.m_extra_rot
                elif ndim == 4:
                    # All wavelengths of this time frame share the same angle
                    angles = np.full(nwave, -1.*parang[i]) + self.m_extra_rot
                else:
                    angles = -parang[frames[i]:frames[i+1]] + self.m_extra_rot

                images = rotate_images(images, angles)

            if self.m_stack is None:
                if ndim == 2:
                    self.m_image_out_port.set_all(images)
                elif ndim == 3:
                    self.m_image_out_port.append(images, data_dim=3)
                elif self.m_dimension == 'time':
                    im_none[i] = images
                else:
                    im_none[:, i] = images

            elif ndim == 4:
                # Collapse the derotated images of this step right away. Dividing a sum by the
                # number of steps (previous mean) or re-reading the raw input (previous median)
                # gave wrong results.
                if self.m_stack == 'mean':
                    im_reduced[i] = np.mean(images, axis=0)
                else:
                    im_reduced[i] = np.median(images, axis=0)

            elif self.m_stack == 'mean':
                im_sum += np.sum(images, axis=0)

            else:
                # A single step covers all images, so this is the median of the full stack
                im_median = np.median(images, axis=0)

        if ndim == 4:
            if self.m_stack is None:
                self.m_image_out_port.set_all(im_none)
            elif self.m_dimension == 'time':
                # Output shape: (nwave, 1, npix, npix)
                self.m_image_out_port.set_all(im_reduced[:, np.newaxis, ...])
            else:
                # Output shape: (1, nimages, npix, npix)
                self.m_image_out_port.set_all(im_reduced[np.newaxis, ...])

        elif self.m_stack == 'mean':
            im_stack = im_sum/float(nimages)
            self.m_image_out_port.set_all(im_stack[np.newaxis, ...])

        elif self.m_stack == 'median':
            self.m_image_out_port.set_all(im_median[np.newaxis, ...])

        if self.m_derotate or self.m_stack is not None:
            self.m_image_out_port.copy_attributes(self.m_image_in_port)

        self.m_image_out_port.close_port()