# Source code for pynpoint.util.psf

"""
Functions for PSF subtraction.
"""

from typing import Optional, Union, Tuple

import numpy as np

from scipy.ndimage import rotate
from sklearn.decomposition import PCA
from typeguard import typechecked

from pynpoint.util.image import scale_image, shift_image


@typechecked
def pca_psf_subtraction(images: np.ndarray,
                        angles: Optional[np.ndarray],
                        pca_number: Union[int, np.int32, np.int64],
                        scales: Optional[np.ndarray] = None,
                        pca_sklearn: Optional[PCA] = None,
                        im_shape: Optional[tuple] = None,
                        indices: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Function for PSF subtraction with PCA.

    Parameters
    ----------
    images : np.ndarray
        Stack of images. Also used as reference images if ``pca_sklearn`` is set to None. The
        data should have the original 3D shape if ``pca_sklearn`` is set to None or it should be
        in a 2D reshaped format if ``pca_sklearn`` is not set to None.
    angles : np.ndarray, None
        Parallactic angles (deg). The residuals are not derotated if set to None.
    pca_number : int
        Number of principal components. Should not exceed the number of components of the
        PCA object.
    scales : np.ndarray, None
        Scaling factors for SDI. Not used if set to None.
    pca_sklearn : sklearn.decomposition.pca.PCA, None
        PCA object with the principal components. A new object is created and fitted on
        ``images`` if set to None.
    im_shape : tuple(int, int, int), None
        The original 3D shape of the stack with images. Only required if ``pca_sklearn`` is not
        set to None. The shape of ``images`` is used instead if ``pca_sklearn`` is set to None.
    indices : np.ndarray, None
        Array with the indices of the pixels that are used for the PSF subtraction. If set to
        None, the non-zero pixels of the first image are used when ``pca_sklearn`` is set to
        None, and all pixels are used when ``pca_sklearn`` is provided.

    Returns
    -------
    np.ndarray
        Residuals of the PSF subtraction (scaled back if ``scales`` is provided).
    np.ndarray
        Derotated residuals of the PSF subtraction. This is the same array object as the first
        return value if ``angles`` is set to None.

    Raises
    ------
    ValueError
        If ``im_shape`` is missing while ``pca_sklearn`` is provided, if ``pca_number`` is
        outside the range of available components, or if the length of ``scales`` or
        ``angles`` does not match the number of images.
    """

    if pca_sklearn is None:
        # Create a PCA object if not provided as argument
        pca_sklearn = PCA(n_components=pca_number, svd_solver='arpack')

        # The 3D shape of the array with images
        im_shape = images.shape

        if indices is None:
            # Select the first image and get the unmasked image indices
            im_star = images[0, ].reshape(-1)
            indices = np.where(im_star != 0.)[0]

        # Reshape the images and select the unmasked pixels
        # The selection with an index array creates a copy so the input is not modified below
        im_reshape = images.reshape(im_shape[0], im_shape[1]*im_shape[2])
        im_reshape = im_reshape[:, indices]

        # Promote integer (or boolean) data to float because the in-place subtraction
        # of the floating-point mean would otherwise raise a casting error
        if not np.issubdtype(im_reshape.dtype, np.inexact):
            im_reshape = im_reshape.astype(np.float64)

        # Subtract the mean image
        # This is also done by sklearn.decomposition.PCA.fit()
        im_reshape -= np.mean(im_reshape, axis=0)

        # Fit the principal components
        pca_sklearn.fit(im_reshape)

    else:
        # The original shape can not be recovered from the 2D reshaped data
        if im_shape is None:
            raise ValueError("The 'im_shape' argument is required when a PCA object is "
                             "provided with 'pca_sklearn'.")

        # If the PCA object is already there then so are the reshaped data
        im_reshape = np.copy(images)

    # Number of components that are stored in the fitted PCA object. This is what
    # inverse_transform expects, also if n_components was not given as an integer.
    n_fitted = pca_sklearn.components_.shape[0]

    if not 0 <= pca_number <= n_fitted:
        raise ValueError(f'The number of principal components ({pca_number}) should be in the '
                         f'range of 0 to {n_fitted}, which is the number of components of the '
                         f'PCA object.')

    # Project the data on the principal components
    # Note that this is the same as sklearn.decomposition.PCA.transform()
    # It is hardcoded because the number of components has been adjusted
    pca_rep = np.matmul(pca_sklearn.components_[:pca_number], im_reshape.T)

    # The zeros are added with vstack to account for the components that have not been used for
    # the transformation to the lower-dimensional space, while they were initiated with the PCA
    # object. Since inverse_transform uses the number of initial components, the zeros are added
    # for components > pca_number. These components do not impact the inverse transformation.
    zeros = np.zeros((n_fitted - pca_number, im_reshape.shape[0]))
    pca_rep = np.vstack((pca_rep, zeros)).T

    # Transform the data back to the original space
    psf_model = pca_sklearn.inverse_transform(pca_rep)

    # Create an array with the original shape
    residuals = np.zeros((im_shape[0], im_shape[1]*im_shape[2]))

    # Select all pixel indices if set to None
    if indices is None:
        indices = np.arange(0, im_reshape.shape[1], 1)

    # Subtract the PSF model
    residuals[:, indices] = im_reshape - psf_model

    # Reshape the residuals to the original shape
    residuals = residuals.reshape(im_shape)

    # ----------- back scale images

    if scales is None:
        scal_cor = residuals

    else:
        # Check if the number of scaling factors is equal to the number of images
        if residuals.shape[0] != scales.shape[0]:
            raise ValueError(f'The number of images ({residuals.shape[0]}) is not equal to the '
                             f'number of wavelengths ({scales.shape[0]}).')

        scal_cor = np.zeros(residuals.shape)

        for i, scale in enumerate(scales):
            # Rescaling the images
            swaps = scale_image(residuals[i, ], 1/scale, 1/scale)

            # Difference in size along the last axis between the original and rescaled image
            # NOTE(review): the same padding is applied to both axes, so this presumably
            # assumes square images that become smaller after rescaling -- confirm
            npix_del = scal_cor.shape[-1] - swaps.shape[-1]

            if npix_del == 0:
                scal_cor[i, ] = swaps

            else:
                # Split the padding over both sides, with the extra pixel
                # at the end in case of an odd difference
                npix_del_a = npix_del // 2
                npix_del_b = npix_del - npix_del_a

                scal_cor[i, npix_del_a:-npix_del_b, npix_del_a:-npix_del_b] = swaps

            # Half a pixel shift to compensate for the asymmetric padding
            if npix_del % 2 == 1:
                scal_cor[i, ] = shift_image(scal_cor[i, ], (0.5, 0.5), interpolation='spline')

    # ----------- derotate images

    if angles is None:
        res_rot = scal_cor

    else:
        # Check if the number of parang is equal to the number of images
        if residuals.shape[0] != angles.shape[0]:
            raise ValueError(f'The number of images ({residuals.shape[0]}) is not equal to the '
                             f'number of parallactic angles ({angles.shape[0]}).')

        res_rot = np.zeros(residuals.shape)

        for j, item in enumerate(angles):
            res_rot[j, ] = rotate(scal_cor[j, ], item, reshape=False)

    return scal_cor, res_rot