# Source code for pynpoint.util.wavelets

"""
Wrapper utilities for the wavelet functions of the mlpy CWT implementation (see continuous.py).
"""

import numpy as np

from numba import jit
from typeguard import typechecked
from scipy.special import gamma, hermite
from scipy.signal import medfilt
from statsmodels.robust import mad

from pynpoint.util.continuous import autoscales, cwt, icwt
# from pynpoint.util.continuous import fourier_from_scales


# This function cannot be decorated with @typechecked because of a compatibility issue with numba
@jit(cache=True, nopython=True)
def _fast_zeros(soft: bool,
                spectrum: np.ndarray,
                uthresh: float) -> np.ndarray:
    """
    Fast numba method to modify values in the wavelet space by using a hard or soft threshold
    function. The input array is modified in place and also returned.

    Soft threshold: every coefficient is replaced by the shrunk real part
    ``sign(re) * (|re| - uthresh)`` if ``|re| > uthresh`` and by zero otherwise (the imaginary
    part is discarded).

    Hard threshold: every coefficient whose (complex) modulus is smaller than ``uthresh`` is set
    to zero; all other coefficients are left untouched.

    Parameters
    ----------
    soft : bool
        If True the soft threshold function will be used, otherwise a hard threshold is applied.
    spectrum : numpy.ndarray
        The input 2D wavelet space. It is overwritten in place.
    uthresh : float
        Threshold used by the threshold function.

    Returns
    -------
    numpy.ndarray
        Modified spectrum (the same array object as ``spectrum``).
    """

    n_rows, n_cols = spectrum.shape

    if soft:
        for i in range(n_rows):
            for j in range(n_cols):
                real_value = spectrum[i, j].real

                # Compare the same quantity that is shrunk. Testing the complex modulus here
                # instead would let |re| - uthresh become negative, which yields a sign-flipped,
                # non-zero coefficient rather than zero.
                if abs(real_value) > uthresh:
                    spectrum[i, j] = np.sign(real_value) * (abs(real_value) - uthresh)
                else:
                    spectrum[i, j] = 0

    else:
        for i in range(n_rows):
            for j in range(n_cols):
                if abs(spectrum[i, j]) < uthresh:
                    spectrum[i, j] = 0

    return spectrum


class WaveletAnalysisCapsule:
    """
    Capsule class to process one 1D time series using the continuous wavelet transform (CWT)
    and wavelet de-noising by wavelet shrinkage.
    """

    @typechecked
    def __init__(self,
                 signal_in: np.ndarray,
                 wavelet_in: str = 'dog',
                 order: int = 2,
                 padding: str = 'none',
                 frequency_resolution: float = 0.5) -> None:
        """
        Parameters
        ----------
        signal_in : numpy.ndarray
            1D input signal.
        wavelet_in : str
            Wavelet function ('dog' or 'morlet').
        order : int
            Order of the wavelet function.
        padding : str
            Padding method ('zero', 'mirror', or 'none').
        frequency_resolution : float
            Wavelet space resolution in scale/frequency.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If the wavelet, the padding method, or the order of the wavelet is not supported.
        """

        # save input data
        self.m_supported_wavelets = ['dog', 'morlet']

        # check supported wavelets
        if wavelet_in not in self.m_supported_wavelets:
            raise ValueError(f'Wavelet {wavelet_in} is not supported')

        # Tabulated reconstruction constants per supported wavelet order; the selected value is
        # used as c_delta in __compute_reconstruction_factor
        if wavelet_in == 'dog':
            self._m_c_reconstructions = {2: 3.5987,
                                         4: 2.4014,
                                         6: 1.9212,
                                         8: 1.6467,
                                         12: 1.3307,
                                         16: 1.1464,
                                         20: 1.0222,
                                         30: 0.8312,
                                         40: 0.7183,
                                         60: 0.5853}

        elif wavelet_in == 'morlet':
            self._m_c_reconstructions = {5: 0.9484,
                                         6: 0.7784,
                                         7: 0.6616,
                                         8: 0.5758,
                                         10: 0.4579,
                                         12: 0.3804,
                                         14: 0.3254,
                                         16: 0.2844,
                                         20: 0.2272}

        self.m_wavelet = wavelet_in

        if padding not in ['none', 'zero', 'mirror']:
            raise ValueError('Padding can only be none, zero or mirror')

        # Validate the order before any data processing is done
        if order not in self._m_c_reconstructions:
            raise ValueError('Wavelet ' + str(wavelet_in) + ' does not support order ' +
                             str(order) + ". \n Only orders: " +
                             str(sorted(self._m_c_reconstructions.keys())).strip('[]') +
                             " are supported")

        self.m_order = order
        self._m_c_final_reconstruction = self._m_c_reconstructions[order]

        # Work on the mean-subtracted signal; the mean is added back in get_signal()
        data_mean = np.mean(signal_in)
        self._m_data = signal_in - np.ones(len(signal_in)) * data_mean
        self._m_data_mean = data_mean

        # Location of the original samples inside the (possibly) padded data. The offset is
        # updated by __pad_signal and used to cut the original signal back out.
        self._m_signal_length = len(signal_in)
        self._m_padding_offset = 0

        self.m_padding = padding
        self.__pad_signal()
        self._m_data_size = len(self._m_data)

        # create scales for wavelet transform
        self._m_scales = autoscales(N=self._m_data_size,
                                    dt=1,
                                    dj=frequency_resolution,
                                    wf=wavelet_in,
                                    p=order)

        self._m_number_of_scales = len(self._m_scales)
        self._m_frequency_resolution = frequency_resolution

        self.m_spectrum = None

    # --- functions for reconstruction value
    @staticmethod
    @typechecked
    def _morlet_function(omega0: float,
                         x_in: float) -> np.complex128:
        """
        Evaluates the Morlet wavelet pi^(-1/4) * exp(i * omega0 * x) * exp(-x^2 / 2).

        Parameters
        ----------
        omega0 : float
            Non-dimensional frequency of the wavelet.
        x_in : float
            Position at which the wavelet is evaluated.

        Returns
        -------
        numpy.complex128
            Morlet function.
        """

        return np.pi**(-0.25) * np.exp(1j * omega0 * x_in) * np.exp(-x_in**2 / 2.0)

    @staticmethod
    @typechecked
    def _dog_function(order: int,
                      x_in: float) -> float:
        """
        Evaluates the derivative-of-Gaussian (DOG) wavelet of order m at x:

            psi(x) = (-1)^(m+1) / sqrt(Gamma(m + 1/2)) * d^m/dx^m exp(-x^2 / 2)
                   = -He_m(x) * exp(-x^2 / 2) / sqrt(Gamma(m + 1/2)),

        with d^m/dx^m exp(-x^2/2) = (-1)^m He_m(x) exp(-x^2/2) and the probabilists' Hermite
        polynomial He_m(x) = 2^(-m/2) H_m(x / sqrt(2)).

        Parameters
        ----------
        order : int
            Order m of the derivative.
        x_in : float
            Position at which the wavelet is evaluated.

        Returns
        -------
        float
            DOG function.
        """

        # hermite() returns the physicists' polynomial H_m as a poly1d-like object, which has to
        # be *called* to evaluate it (indexing it would return a coefficient instead).
        herm = hermite(order)(x_in / np.power(2, 0.5)) / np.power(2, float(order) / 2)

        # (-1)^(m+1) * (-1)^m = -1 for every order m
        return -1.0 / np.sqrt(gamma(order + 0.5)) * herm * np.exp(-x_in**2 / 2.0)

    @typechecked
    def __pad_signal(self) -> None:
        """
        Pads the internal signal according to ``m_padding`` and stores the number of prepended
        samples in ``_m_padding_offset``.

        'zero' adds len // 2 zeros on both sides, 'mirror' prepends the reversed first half and
        appends the reversed second half of the signal, and 'none' leaves the data unchanged.

        Returns
        -------
        NoneType
            None
        """

        padding_length = int(len(self._m_data) * 0.5)

        if self.m_padding == 'zero':
            new_data = np.append(self._m_data, np.zeros(padding_length, dtype=np.float64))
            self._m_data = np.append(np.zeros(padding_length, dtype=np.float64), new_data)

        elif self.m_padding == 'mirror':
            left_half_signal = self._m_data[:padding_length]
            right_half_signal = self._m_data[padding_length:]
            new_data = np.append(self._m_data, right_half_signal[::-1])
            self._m_data = np.append(left_half_signal[::-1], new_data)

        else:
            padding_length = 0

        self._m_padding_offset = padding_length

    @typechecked
    def __signal_window(self) -> slice:
        """
        Slice that selects the samples of the original (unpadded) signal from the padded data
        or from one row of the wavelet space. This is exact for odd and even signal lengths.

        Returns
        -------
        slice
            Window of the original signal.
        """

        return slice(self._m_padding_offset, self._m_padding_offset + self._m_signal_length)

    @typechecked
    def __compute_reconstruction_factor(self) -> float:
        """
        Computes the reconstruction factor frequency_resolution / (c_delta * psi(0)), which is
        used to rescale the output of the inverse CWT.

        Returns
        -------
        float
            Reconstruction factor.
        """

        # Pass floats explicitly since both helpers are type checked with float arguments
        if self.m_wavelet == 'morlet':
            zero_function = self._morlet_function(float(self.m_order), 0.0)
        else:
            zero_function = self._dog_function(self.m_order, 0.0)

        c_delta = self._m_c_final_reconstruction
        reconstruction_factor = self._m_frequency_resolution / (c_delta * zero_function)

        # The Morlet value is complex (with zero imaginary part at x = 0)
        return reconstruction_factor.real

    @typechecked
    def compute_cwt(self) -> None:
        """
        Compute the wavelet space of the given input signal.

        Returns
        -------
        NoneType
            None
        """

        self.m_spectrum = cwt(self._m_data,
                              dt=1,
                              scales=self._m_scales,
                              wf=self.m_wavelet,
                              p=self.m_order)

    @typechecked
    def update_signal(self) -> None:
        """
        Updates the internal signal by the reconstruction of the current wavelet space.

        Returns
        -------
        NoneType
            None
        """

        self._m_data = icwt(self.m_spectrum, scales=self._m_scales)

        reconstruction_factor = self.__compute_reconstruction_factor()
        self._m_data *= reconstruction_factor

    @typechecked
    def denoise_spectrum(self,
                         soft: bool = False) -> None:
        """
        Applies wavelet shrinkage on the current wavelet space (m_spectrum) by either a hard or
        soft threshold function. The noise level is estimated with the median absolute deviation
        of the first row of the wavelet space. Only the samples of the original (unpadded) signal
        are used for this estimate.

        Parameters
        ----------
        soft : bool
            If True a soft threshold is used, hard otherwise.

        Returns
        -------
        NoneType
            None
        """

        # First row of the spectrum (presumably the smallest scale returned by autoscales --
        # verify), restricted to the original signal so that padded samples do not bias sigma
        noise_spectrum = self.m_spectrum[0, self.__signal_window()].real

        sigma = mad(noise_spectrum)

        # Universal threshold sigma * sqrt(2 ln N)
        uthresh = sigma * np.sqrt(2.0 * np.log(len(noise_spectrum)))

        self.m_spectrum = _fast_zeros(soft, self.m_spectrum, uthresh)

    @typechecked
    def median_filter(self,
                      kernel_size: int = 19) -> None:
        """
        Applies a median filter on the internal 1D signal. This can be useful for cosmic ray
        correction after temporal de-noising.

        Parameters
        ----------
        kernel_size : int
            Size of the median filter window. It must be odd (required by
            scipy.signal.medfilt). The default (19) is the previously hard-coded value.

        Returns
        -------
        NoneType
            None
        """

        self._m_data = medfilt(self._m_data, kernel_size)

    @typechecked
    def get_signal(self) -> np.ndarray:
        """
        Returns the current version of the 1D signal. Use update_signal() in advance in order
        to get the current reconstruction of the wavelet space. Removes padded values as well
        and adds back the mean of the input signal.

        Returns
        -------
        numpy.ndarray
            Current version of the 1D signal.
        """

        tmp_data = self._m_data + np.ones(len(self._m_data)) * self._m_data_mean

        if self.m_padding == 'none':
            return tmp_data

        # Cut out exactly the original samples; this is also correct for odd signal lengths
        return tmp_data[self.__signal_window()]
# def __transform_period(self, # period): # # tmp_y = fourier_from_scales(self._m_scales, # self.m_wavelet, # self.m_order) # # def __transformation(x): # return np.log2(x + 1) * tmp_y[-1] / np.log2(tmp_y[-1] + 1) # # cutoff_scaled = __transformation(period) # # scale_new = tmp_y[-1] - tmp_y[0] # scale_old = self.m_spectrum.shape[0] # # factor = scale_old / scale_new # cutoff_scaled *= factor # # return cutoff_scaled # ----- plotting functions -------- # def __plot_or_save_spectrum(self): # plt.close() # # plt.figure(figsize=(8, 6)) # plt.subplot(1, 1, 1) # # tmp_y = fourier_from_scales(self._m_scales, # self.m_wavelet, # self.m_order) # # tmp_x = np.arange(0, self._m_data_size + 1, 1) # # scaled_spec = copy.deepcopy(self.m_spectrum.real) # for i, _ in enumerate(scaled_spec): # scaled_spec[i] /= np.sqrt(self._m_scales[i]) # # plt.imshow(abs(scaled_spec), # aspect='auto', # extent=[tmp_x[0], # tmp_x[-1], # tmp_y[0], # tmp_y[-1]], # cmap=plt.get_cmap("gist_ncar"), # origin='lower') # # # COI first part (only for DOG) with padding # # inner_frequency = 2.*np.pi/np.sqrt(self.m_order + 0.5) # coi = np.append(np.zeros(len(tmp_x)/4), # tmp_x[0:len(tmp_x) / 4]) # coi = np.append(coi, # tmp_x[0:len(tmp_x) / 4][::-1]) # coi = np.append(coi, # np.zeros(len(tmp_x) / 4)) # # plt.plot(np.arange(0, len(coi), 1.0), # inner_frequency * coi / np.sqrt(2), # color="white") # # plt.ylim([tmp_y[0], # tmp_y[-1]]) # # plt.fill_between(np.arange(0, len(coi), 1.0), # inner_frequency * coi / np.sqrt(2), # np.ones(len(coi)) * tmp_y[-1], # facecolor="none", # edgecolor='white', # alpha=0.4, # hatch="x") # # plt.yscale('log', basey=2) # plt.ylabel("Period in [s]") # plt.xlabel("Time in [s]") # plt.title("Spectrum computed with CWT using '" + str(self.m_wavelet) + # "' wavelet of order " + str(self.m_order)) # # def plot_spectrum(self): # """ # Shows a plot of the current wavelet space. 
# :return: None # """ # # self.__plot_or_save_spectrum() # plt.show() # # def save_spectrum(self, # location): # """ # Saves a plot of the current wavelet space to a given location. # :param location: Save location # :type location: str # :return: None # """ # self.__plot_or_save_spectrum() # plt.savefig(location) # plt.close() # # def __plot_or_save_signal(self): # plt.close() # plt.plot(self._m_data) # plt.title("Signal") # plt.ylabel("Value of the function") # plt.xlim([0, self._m_data_size]) # plt.xlabel("Time in [s]") # # def plot_signal(self): # """ # Plot the current signal. # :return: None # """ # self.__plot_or_save_signal() # plt.show() # # def save_signal(self, # location): # """ # Saves a plot of the current signal to a given location. # :param location: Save location # :type location: str # :return: None # """ # self.__plot_or_save_signal() # plt.savefig(location)