# Source code for pynpoint.util.multipca

"""
Capsule for multiprocessing of the PSF subtraction with PCA. Residuals are computed in parallel for
a range of principal component numbers; the fitted PCA basis has to be provided as input.
"""

import sys
import multiprocessing

from typing import List, Optional, Tuple, Union

import numpy as np

from typeguard import typechecked
from sklearn.decomposition import PCA

from pynpoint.core.dataio import OutputPort
from pynpoint.util.multiproc import TaskProcessor, TaskCreator, TaskWriter, TaskResult, \
                                    TaskInput, MultiprocessingCapsule, to_slice
from pynpoint.util.postproc import postprocessor
from pynpoint.util.residuals import combine_residuals


class PcaTaskCreator(TaskCreator):
    """
    Task creator of the PCA multiprocessing. One task is queued per principal component number,
    or per combination of two numbers when two sets of numbers are provided. No input port is
    used because the task processors hold their own copy of the data.
    """

    @typechecked
    def __init__(self,
                 tasks_queue_in: multiprocessing.JoinableQueue,
                 num_proc: int,
                 pca_numbers: Union[np.ndarray, tuple]) -> None:
        """
        Parameters
        ----------
        tasks_queue_in : multiprocessing.queues.JoinableQueue
            Input task queue.
        num_proc : int
            Number of processors.
        pca_numbers : np.ndarray, tuple
            Principal components for which the residuals are computed. When a tuple is given,
            its first two elements are sequences and every (first, second) combination becomes
            a separate task.

        Returns
        -------
        NoneType
            None
        """

        super(PcaTaskCreator, self).__init__(None, tasks_queue_in, None, num_proc)

        self.m_pca_numbers = pca_numbers

    @typechecked
    def run(self) -> None:
        """
        Run method of PcaTaskCreator. Puts one TaskInput per principal component number (or
        pair of numbers) on the task queue, followed by the poison pills.

        Returns
        -------
        NoneType
            None
        """

        # Placeholder for an axis that is not used to position the result
        unused_axis = (None, None, None)

        if isinstance(self.m_pca_numbers, tuple):
            first_numbers = self.m_pca_numbers[0]
            second_numbers = self.m_pca_numbers[1]

            for idx_first, n_first in enumerate(first_numbers):
                for idx_second, n_second in enumerate(second_numbers):
                    # The result is stored at element (idx_first, idx_second) of the output
                    position = ((idx_first, idx_first + 1, None),
                                (idx_second, idx_second + 1, None),
                                unused_axis)

                    self.m_task_queue.put(TaskInput((n_first, n_second), (position, )))

        else:
            for idx, n_comp in enumerate(self.m_pca_numbers):
                # The result is stored at element idx along the first output axis
                position = ((idx, idx + 1, None), unused_axis, unused_axis)

                self.m_task_queue.put(TaskInput(n_comp, (position, )))

        self.create_poison_pills()
class PcaTaskProcessor(TaskProcessor):
    """
    Task processor of the PCA multiprocessing and the core of the parallelization. Each task
    runs one forward and backward PCA transformation with the pre-trained scikit-learn model.
    The image data are not received from the TaskCreator; every processor keeps its own copy,
    which is identical for all tasks. Depending on ``requirements``, the following residuals
    are created:

        * Mean residuals -- requirements[0] = True
        * Median residuals -- requirements[1] = True
        * Noise-weighted residuals -- requirements[2] = True
        * Clipped mean of the residuals -- requirements[3] = True
    """

    @typechecked
    def __init__(self,
                 tasks_queue_in: multiprocessing.JoinableQueue,
                 result_queue_in: multiprocessing.JoinableQueue,
                 star_reshape: np.ndarray,
                 angles: np.ndarray,
                 scales: Optional[np.ndarray],
                 pca_model: Optional[PCA],
                 im_shape: tuple,
                 indices: Optional[np.ndarray],
                 requirements: Tuple[bool, bool, bool, bool],
                 processing_type: str) -> None:
        """
        Parameters
        ----------
        tasks_queue_in : multiprocessing.queues.JoinableQueue
            Input task queue.
        result_queue_in : multiprocessing.queues.JoinableQueue
            Input result queue.
        star_reshape : np.ndarray
            Reshaped (2D) stack of images.
        angles : np.ndarray
            Derotation angles (deg).
        scales : np.ndarray, None
            Scaling factors.
        pca_model : sklearn.decomposition.pca.PCA, None
            PCA object with the basis.
        im_shape : tuple(int, int, int)
            Original shape of the stack of images.
        indices : np.ndarray, None
            Non-masked image indices.
        requirements : tuple(bool, bool, bool, bool)
            Required output residuals (mean, median, weighted, clipped).
        processing_type : str
            Selected processing type.

        Returns
        -------
        NoneType
            None
        """

        super(PcaTaskProcessor, self).__init__(tasks_queue_in, result_queue_in)

        self.m_star_reshape = star_reshape
        self.m_angles = angles
        self.m_scales = scales
        self.m_pca_model = pca_model
        self.m_im_shape = im_shape
        self.m_indices = indices
        self.m_requirements = requirements
        self.m_processing_type = processing_type

    @typechecked
    def run_job(self, tmp_task: TaskInput) -> TaskResult:
        """
        Run method of PcaTaskProcessor. Computes the residuals for the number of principal
        components in the task and combines them with every requested method.

        Parameters
        ----------
        tmp_task : pynpoint.util.multiproc.TaskInput
            Input task.

        Returns
        -------
        pynpoint.util.multiproc.TaskResult
            Output residuals, stacked along the first axis in the order mean, median,
            weighted, clipped (entries that are not required stay zero).
        """

        # A pair of component numbers is passed through unchanged, a single one becomes an int
        task_data = tmp_task.m_input_data
        pca_number = task_data if isinstance(task_data, tuple) else int(task_data)

        residuals, res_rot = postprocessor(images=self.m_star_reshape,
                                           angles=self.m_angles,
                                           scales=self.m_scales,
                                           pca_number=pca_number,
                                           pca_sklearn=self.m_pca_model,
                                           im_shape=self.m_im_shape,
                                           indices=self.m_indices,
                                           processing_type=self.m_processing_type)

        # IFS versus mono-wavelength data: a 3D res_rot yields one frame per residual type,
        # otherwise one frame per entry of star_reshape is kept for each residual type
        if res_rot.ndim == 3:
            out_shape = (4, res_rot.shape[-2], res_rot.shape[-1])
        else:
            out_shape = (4, len(self.m_star_reshape), res_rot.shape[-2], res_rot.shape[-1])

        res_output = np.zeros(out_shape)

        # Keyword arguments per residual type, in the same order as the requirements flags
        combine_kwargs = ({'method': 'mean'},
                          {'method': 'median'},
                          {'method': 'weighted', 'residuals': residuals, 'angles': self.m_angles},
                          {'method': 'clipped'})

        for idx, (needed, kwargs) in enumerate(zip(self.m_requirements, combine_kwargs)):
            if needed:
                res_output[idx, ] = combine_residuals(res_rot=res_rot, **kwargs)

        # Progress indicator: one dot per finished task
        sys.stdout.write('.')
        sys.stdout.flush()

        return TaskResult(res_output, tmp_task.m_job_parameter[0])
class PcaTaskWriter(TaskWriter):
    """
    The TaskWriter of the PCA parallelization. Four different ports are used to save the
    results of the task processors (mean, median, weighted, and clipped).
    """

    @typechecked
    def __init__(self,
                 result_queue_in: multiprocessing.JoinableQueue,
                 mean_out_port: Optional[OutputPort],
                 median_out_port: Optional[OutputPort],
                 weighted_out_port: Optional[OutputPort],
                 clip_out_port: Optional[OutputPort],
                 data_mutex_in: multiprocessing.Lock,
                 requirements: Tuple[bool, bool, bool, bool]) -> None:
        """
        Constructor of PcaTaskWriter.

        Parameters
        ----------
        result_queue_in : multiprocessing.queues.JoinableQueue
            Input result queue.
        mean_out_port : pynpoint.core.dataio.OutputPort
            Output port with the mean residuals. Not used if set to None.
        median_out_port : pynpoint.core.dataio.OutputPort
            Output port with the median residuals. Not used if set to None.
        weighted_out_port : pynpoint.core.dataio.OutputPort
            Output port with the noise-weighted residuals. Not used if set to None.
        clip_out_port : pynpoint.core.dataio.OutputPort
            Output port with the clipped mean residuals. Not used if set to None.
        data_mutex_in : multiprocessing.synchronize.Lock
            A mutual exclusion variable which ensures that no read and write occur
            simultaneously.
        requirements : tuple(bool, bool, bool, bool)
            Required output residuals (mean, median, weighted, clipped).

        Returns
        -------
        NoneType
            None
        """

        super(PcaTaskWriter, self).__init__(result_queue_in, None, data_mutex_in)

        self.m_mean_out_port = mean_out_port
        self.m_median_out_port = median_out_port
        self.m_weighted_out_port = weighted_out_port
        self.m_clip_out_port = clip_out_port
        self.m_requirements = requirements

    @typechecked
    def run(self) -> None:
        """
        Run method of PcaTaskWriter. Writes the residuals to the output ports until the
        poison pill indicates that all results have been processed.

        Returns
        -------
        NoneType
            None
        """

        # The position of each port in this tuple is also the index of its residuals in
        # ``m_data_array`` (0: mean, 1: median, 2: weighted, 3: clipped)
        out_ports = (self.m_mean_out_port,
                     self.m_median_out_port,
                     self.m_weighted_out_port,
                     self.m_clip_out_port)

        while True:
            next_result = self.m_result_queue.get()
            poison_pill_case = self.check_poison_pill(next_result)

            if poison_pill_case == 1:
                break

            if poison_pill_case == 2:
                continue

            position = next_result.m_position

            # Results are stored at a single index along the first axis, or at a
            # (first, second) index when the task was created for a pair of component numbers
            if position[1][0] is None:
                res_slice = position[0][0]
            else:
                res_slice = (position[0][0], position[1][0])

            with self.m_data_mutex:
                for idx, (required, port) in enumerate(zip(self.m_requirements, out_ports)):
                    if not required:
                        continue

                    port._check_status_and_activate()

                    # Close the port even if the write fails so it is not left open
                    try:
                        port[res_slice] = next_result.m_data_array[idx]
                    finally:
                        port.close_port()

            self.m_result_queue.task_done()
class PcaMultiprocessingCapsule(MultiprocessingCapsule):
    """
    Capsule for PCA multiprocessing with the poison pill pattern.
    """

    @typechecked
    def __init__(self,
                 mean_out_port: Optional[OutputPort],
                 median_out_port: Optional[OutputPort],
                 weighted_out_port: Optional[OutputPort],
                 clip_out_port: Optional[OutputPort],
                 num_proc: int,
                 pca_numbers: Union[tuple, np.ndarray],
                 pca_model: Optional[PCA],
                 star_reshape: np.ndarray,
                 angles: np.ndarray,
                 scales: Optional[np.ndarray],
                 im_shape: tuple,
                 indices: Optional[np.ndarray],
                 processing_type: str) -> None:
        """
        Constructor of PcaMultiprocessingCapsule.

        Parameters
        ----------
        mean_out_port : pynpoint.core.dataio.OutputPort, None
            Output port for the mean residuals. Not computed if set to None.
        median_out_port : pynpoint.core.dataio.OutputPort, None
            Output port for the median residuals. Not computed if set to None.
        weighted_out_port : pynpoint.core.dataio.OutputPort, None
            Output port for the noise-weighted residuals. Not computed if set to None.
        clip_out_port : pynpoint.core.dataio.OutputPort, None
            Output port for the mean clipped residuals. Not computed if set to None.
        num_proc : int
            Number of processors.
        pca_numbers : np.ndarray, tuple
            Numbers of principal components, or a tuple with two sets of numbers.
        pca_model : sklearn.decomposition.pca.PCA, None
            PCA object with the basis.
        star_reshape : np.ndarray
            Reshaped (2D) input images.
        angles : np.ndarray
            Derotation angles (deg).
        scales : np.ndarray, None
            Scaling factors.
        im_shape : tuple(int, int, int)
            Original shape of the input images.
        indices : np.ndarray, None
            Non-masked pixel indices.
        processing_type : str
            Selection of processing type.

        Returns
        -------
        NoneType
            None
        """

        self.m_mean_out_port = mean_out_port
        self.m_median_out_port = median_out_port
        self.m_weighted_out_port = weighted_out_port
        self.m_clip_out_port = clip_out_port

        self.m_pca_numbers = pca_numbers
        self.m_pca_model = pca_model
        self.m_star_reshape = star_reshape
        self.m_angles = angles
        self.m_scales = scales
        self.m_im_shape = im_shape
        self.m_indices = indices
        self.m_processing_type = processing_type

        # A residual type is only computed when an output port has been provided for it
        self.m_requirements = tuple(port is not None for port in (mean_out_port,
                                                                  median_out_port,
                                                                  weighted_out_port,
                                                                  clip_out_port))

        # Called last: the parent constructor builds the writer, creator, and processors,
        # which read the attributes assigned above
        super(PcaMultiprocessingCapsule, self).__init__(None, None, num_proc)

    @typechecked
    def create_writer(self, image_out_port: None) -> PcaTaskWriter:
        """
        Method to create an instance of PcaTaskWriter.

        Parameters
        ----------
        image_out_port : None
            Output port, not used.

        Returns
        -------
        pynpoint.util.multipca.PcaTaskWriter
            PCA task writer.
        """

        return PcaTaskWriter(self.m_result_queue,
                             self.m_mean_out_port,
                             self.m_median_out_port,
                             self.m_weighted_out_port,
                             self.m_clip_out_port,
                             self.m_data_mutex,
                             self.m_requirements)

    @typechecked
    def init_creator(self, image_in_port: None) -> PcaTaskCreator:
        """
        Method to create an instance of PcaTaskCreator.

        Parameters
        ----------
        image_in_port : None
            Input port, not used.

        Returns
        -------
        pynpoint.util.multipca.PcaTaskCreator
            PCA task creator.
        """

        return PcaTaskCreator(self.m_tasks_queue, self.m_num_proc, self.m_pca_numbers)

    @typechecked
    def create_processors(self) -> List[PcaTaskProcessor]:
        """
        Method to create a list of instances of PcaTaskProcessor, one per processor.

        Returns
        -------
        list(pynpoint.util.multipca.PcaTaskProcessor, )
            PCA task processors.
        """

        return [PcaTaskProcessor(self.m_tasks_queue,
                                 self.m_result_queue,
                                 self.m_star_reshape,
                                 self.m_angles,
                                 self.m_scales,
                                 self.m_pca_model,
                                 self.m_im_shape,
                                 self.m_indices,
                                 self.m_requirements,
                                 self.m_processing_type)
                for _ in range(self.m_num_proc)]