"""
Functions for combining the residuals of the PSF subtraction.
"""
from typing import Optional
import numpy as np
from typeguard import typechecked
from scipy.ndimage import rotate
@typechecked
def combine_residuals(method: str,
                      res_rot: np.ndarray,
                      residuals: Optional[np.ndarray] = None,
                      angles: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Wavelength wrapper for combining the residuals of the PSF subtraction. A 3D input is
    combined into a single image, while a 4D input (wavelengths along the first axis) is
    combined separately for each wavelength.

    Parameters
    ----------
    method : str
        Method used for combining the residuals ('mean', 'median', 'weighted', or 'clipped').
    res_rot : np.ndarray
        Derotated residuals of the PSF subtraction (3D, or 4D with the wavelengths along
        the first axis).
    residuals : np.ndarray, None
        Non-derotated residuals of the PSF subtraction (same dimensionality as ``res_rot``).
        Only required for the noise-weighted residuals.
    angles : np.ndarray, None
        Derotation angles (deg). Only required for the noise-weighted residuals. The same
        angles are used for every wavelength.

    Returns
    -------
    np.ndarray
        Collapsed residuals (3D). The first axis has length 1 for a 3D input and the
        number of wavelengths for a 4D input.

    Raises
    ------
    ValueError
        If ``res_rot`` is neither 3D nor 4D.
    """

    if res_rot.ndim == 3:
        return _residuals(method=method,
                          res_rot=np.asarray(res_rot),
                          residuals=residuals,
                          angles=angles)

    if res_rot.ndim == 4:
        output = np.zeros((res_rot.shape[0], res_rot.shape[2], res_rot.shape[3]))

        for i, res_rot_wavel in enumerate(res_rot):
            res_wavel = None if residuals is None else residuals[i, ]

            # _residuals returns a single image with a leading axis of length 1
            output[i, ] = _residuals(method=method,
                                     res_rot=res_rot_wavel,
                                     residuals=res_wavel,
                                     angles=angles)[0]

        return output

    raise ValueError('The array with derotated residuals should be 3D or 4D but it has '
                     '{} dimensions.'.format(res_rot.ndim))
@typechecked
def _residuals(method: str,
               res_rot: np.ndarray,
               residuals: Optional[np.ndarray] = None,
               angles: Optional[np.ndarray] = None) -> np.ndarray:
    """
    Function for combining the derotated residuals of the PSF subtraction.

    Parameters
    ----------
    method : str
        Method used for combining the residuals ('mean', 'median', 'weighted', or 'clipped').
    res_rot : np.ndarray
        Derotated residuals of the PSF subtraction (3D).
    residuals : np.ndarray, None
        Non-derotated residuals of the PSF subtraction (3D). Only required for the
        noise-weighted residuals.
    angles : np.ndarray, None
        Derotation angles (deg). Only required for the noise-weighted residuals.

    Returns
    -------
    np.ndarray
        Combined residuals (3D), with a single image along the first axis.

    Raises
    ------
    ValueError
        If ``method`` is not supported, or if ``method='weighted'`` and either
        ``residuals`` or ``angles`` is not provided.
    """

    if method == 'mean':
        stack = np.mean(res_rot, axis=0)

    elif method == 'median':
        stack = np.median(res_rot, axis=0)

    elif method == 'weighted':
        if residuals is None or angles is None:
            raise ValueError('The \'weighted\' method requires both the non-derotated '
                             'residuals and the derotation angles.')

        stack = _weighted_mean(res_rot, residuals, angles)

    elif method == 'clipped':
        stack = _clipped_mean(res_rot)

    else:
        raise ValueError('Combination method \'{}\' is not supported. Use \'mean\', '
                         '\'median\', \'weighted\', or \'clipped\'.'.format(method))

    return stack[np.newaxis, ...]


@typechecked
def _weighted_mean(res_rot: np.ndarray,
                   residuals: np.ndarray,
                   angles: np.ndarray) -> np.ndarray:
    """
    Noise-weighted mean of the derotated residuals along the first axis.

    Each derotated frame is weighted by the inverse of the pixel variance of the
    non-derotated residuals, with the variance map rotated by that frame's angle.
    Pixels whose variance is NaN or has an absolute value of at most 1e-100 get zero
    weight. Pixels whose summed weight is also that small are set to zero.

    Parameters
    ----------
    res_rot : np.ndarray
        Derotated residuals (3D).
    residuals : np.ndarray
        Non-derotated residuals (3D).
    angles : np.ndarray
        Derotation angles (deg), one per frame.

    Returns
    -------
    np.ndarray
        Weighted mean image (2D).
    """

    # Variance of the non-derotated residuals along the first axis (one map for all frames)
    pix_var = np.var(residuals, axis=0)

    res_var = np.zeros((residuals.shape[0], ) + pix_var.shape)

    for j, angle in enumerate(angles):
        # scipy.ndimage.rotate rotates in clockwise direction for positive angles
        res_var[j, ] = rotate(input=pix_var, angle=angle, reshape=False)

    # NaN variances fail this comparison, so they are excluded as well
    valid = np.abs(res_var) > 1e-100

    weight1 = np.divide(res_rot, res_var, out=np.zeros_like(res_var), where=valid)
    weight2 = np.divide(1., res_var, out=np.zeros_like(res_var), where=valid)

    sum1 = np.sum(weight1, axis=0)
    sum2 = np.sum(weight2, axis=0)

    return np.divide(sum1, sum2, out=np.zeros_like(sum2), where=np.abs(sum2) > 1e-100)


@typechecked
def _clipped_mean(res_rot: np.ndarray,
                  n_sigma: float = 3.) -> np.ndarray:
    """
    Sigma-clipped mean of the derotated residuals along the first axis.

    For each pixel, values that deviate from the pixel mean by ``n_sigma`` standard
    deviations or more are rejected, and the mean of the remaining values is returned.
    Pixels with a NaN variance are set to zero. When no value survives the clipping
    (e.g. a constant pixel), the unclipped mean is returned.

    Parameters
    ----------
    res_rot : np.ndarray
        Derotated residuals (3D).
    n_sigma : float
        Clipping threshold in units of the standard deviation.

    Returns
    -------
    np.ndarray
        Clipped mean image (2D).
    """

    pix_mean = np.mean(res_rot, axis=0)
    pix_var = np.var(res_rot, axis=0)

    no_mean = res_rot - pix_mean
    threshold = n_sigma * np.sqrt(np.var(no_mean, axis=0))

    # Strict two-sided clipping around the mean
    keep = (no_mean < threshold) & (no_mean > -threshold)
    n_keep = np.sum(keep, axis=0)

    clipped_sum = np.sum(np.where(keep, no_mean, 0.), axis=0)

    # Mean offset of the retained values; zero where nothing survived the clipping
    offset = np.divide(clipped_sum,
                       n_keep,
                       out=np.zeros_like(clipped_sum),
                       where=n_keep > 0)

    return np.where(np.isnan(pix_var), 0., pix_mean + offset).astype(np.float64)