import numpy as np
from scipy.signal import hilbert
from PyEMD.compact import filt6, pade6
# Visualisation is an optional module. To minimise installation, `matplotlib` is not added
# by default. Please install extras with `pip install -r requirement-extra.txt`.
try:
import pylab as plt
except ImportError:
pass
class Visualisation(object):
    """Simple visualisation helper.

    This class is for quick and simple result visualisation. It can be
    created with an EMD-like object (anything exposing
    ``get_imfs_and_residue()``), in which case the IMFs and residue are read
    from it once at construction time, or without one, in which case the
    data has to be handed to the plotting methods directly.

    The plotting methods only draw figures; call :meth:`show` to display them.
    """

    # Figure width and per-row height, passed to matplotlib's ``figsize``.
    PLOT_WIDTH = 6
    PLOT_HEIGHT_PER_IMF = 1.5

    def __init__(self, emd_instance=None):
        """Store the optional EMD instance and cache its IMFs and residue."""
        self.emd_instance = emd_instance

        self.imfs = None
        self.residue = None

        if emd_instance is not None:
            self.imfs, self.residue = self.emd_instance.get_imfs_and_residue()

    @staticmethod
    def _validate_alpha(alpha):
        """Raise ``AssertionError`` unless ``-0.5 < alpha < 0.5``.

        An explicit raise is used instead of an ``assert`` statement so the
        check is not stripped when Python runs with ``-O``; the exception
        type and message are the ones callers have always seen.
        """
        if not (-0.5 < alpha < 0.5):
            raise AssertionError("`alpha` must be in between -0.5 and 0.5")

    def _create_axes(self, num_rows):
        """Create a figure with ``num_rows`` stacked subplots.

        Returns the figure and a 1-D sequence of axes. ``squeeze=False``
        makes ``plt.subplots`` return a 2-D array even for a single row, so
        callers can always index the result.
        """
        fig, axes = plt.subplots(
            num_rows,
            1,
            squeeze=False,
            figsize=(self.PLOT_WIDTH, num_rows * self.PLOT_HEIGHT_PER_IMF),
        )
        return fig, axes[:, 0]

    def _check_imfs(self, imfs, residue, include_residue):
        """Resolve imfs and residue, falling back to the stored values.

        Raises
        ------
        AttributeError
            If no imfs are available, or if `include_residue` is set and no
            residue is available.
        """
        imfs = imfs if imfs is not None else self.imfs
        residue = residue if residue is not None else self.residue

        if imfs is None:
            raise AttributeError("No imfs passed to plot")

        if include_residue and residue is None:
            raise AttributeError("Requested to plot residue but no residue provided")

        return imfs, residue

    def plot_imfs(self, imfs=None, residue=None, t=None, include_residue=True):
        """Plots all IMFs, one per row, optionally followed by the residue.

        All parameters are optional since the `emd` object could have been
        passed when instantiating this object.

        Parameters
        ----------
        imfs : 2D array, optional
            IMFs to plot, one per row. Defaults to the stored IMFs.
        residue : 1D array, optional
            Residue to plot in the last row. Defaults to the stored residue.
        t : array-like, optional
            Values for the horizontal axis. Defaults to sample indices.
        include_residue : bool (default: True)
            Whether to add a final row with the residue.

        Raises
        ------
        AttributeError
            If the required imfs or residue are not available.
        """
        imfs, residue = self._check_imfs(imfs, residue, include_residue)
        num_rows, t_length = imfs.shape
        # Use truthiness here so the row count always agrees with the
        # `if include_residue` branch below; otherwise a truthy non-True flag
        # would draw the residue over the last IMF.
        if include_residue:
            num_rows += 1

        t = t if t is not None else range(t_length)

        _, axes = self._create_axes(num_rows)
        axes[0].set_title("Time series")

        for num, imf in enumerate(imfs):
            ax = axes[num]
            ax.plot(t, imf)
            ax.set_ylabel("IMF " + str(num + 1))

        if include_residue:
            ax = axes[-1]
            ax.plot(t, residue)
            ax.set_ylabel("Res")

        # Making the layout a bit more pleasant to the eye
        plt.tight_layout()

    def plot_instant_freq(self, t, imfs=None, order=False, alpha=None):
        """Plots instantaneous frequencies for all provided imfs.

        The necessary parameter is `t` which is the time array used to compute
        the EMD. One should pass `imfs` if no `emd` instance was passed when
        creating the Visualisation object.

        Parameters
        ----------
        t : array
            Time array. Only ``t[1] - t[0]`` is used as the sampling step.
        imfs : 2D array, optional
            IMFs to analyse, one per row. Defaults to the stored IMFs.
        order : bool (default: False)
            Represents whether the finite difference scheme is
            low-order (1st order forward scheme) or high-order (6th order
            compact scheme). The default value is False (low-order).
        alpha : float (default: None)
            Filter intensity. None means that no filter is applied.
            Otherwise the value must lie strictly between -0.5
            (towards fully active) and 0.5 (towards no filter).

        Raises
        ------
        AssertionError
            If `alpha` is given but is not strictly between -0.5 and 0.5.
        AttributeError
            If no imfs are available.
        """
        if alpha is not None:
            self._validate_alpha(alpha)

        imfs, _ = self._check_imfs(imfs, None, False)
        num_rows = imfs.shape[0]

        imfs_inst_freqs = self._calc_inst_freq(imfs, t, order=order, alpha=alpha)

        _, axes = self._create_axes(num_rows)
        axes[0].set_title("Instantaneous frequency")

        for num, imf_inst_freq in enumerate(imfs_inst_freqs):
            ax = axes[num]
            ax.plot(t, imf_inst_freq)
            ax.set_ylabel("IMF {} [Hz]".format(num + 1))

        # Making the layout a bit more pleasant to the eye
        plt.tight_layout()

    def _calc_inst_phase(self, sig, alpha):
        """Extract the unwrapped instantaneous phase via the Hilbert Transform.

        When `alpha` is given, the real and imaginary parts of the analytic
        signal and then the resulting phase are each passed row by row
        through `filt6`.
        """
        analytic_signal = hilbert(sig)  # Apply Hilbert transform to each row
        if alpha is not None:
            self._validate_alpha(alpha)
            real_part = np.array([filt6(row.real, alpha) for row in analytic_signal])
            imag_part = np.array([filt6(row.imag, alpha) for row in analytic_signal])
            analytic_signal = real_part + 1j * imag_part
        phase = np.unwrap(np.angle(analytic_signal))  # Compute angle between img and real
        if alpha is not None:
            phase = np.array([filt6(row, alpha) for row in phase])  # Filter phase
        return phase

    def _calc_inst_freq(self, sig, t, order, alpha):
        """Extracts instantaneous frequency through the Hilbert Transform.

        The sampling step is taken from the first two entries of `t` only.
        With ``order is False`` the phase is differentiated with `np.diff`
        and the last column is repeated so the output has as many samples as
        the input; otherwise `pade6` is applied row by row. When `alpha` is
        given the frequencies are filtered with `filt6` before returning.
        """
        inst_phase = self._calc_inst_phase(sig, alpha=alpha)
        dt = t[1] - t[0]
        if order is False:
            inst_freqs = np.diff(inst_phase) / (2 * np.pi * dt)
            # np.diff drops one sample per row; pad by repeating the last one.
            inst_freqs = np.concatenate((inst_freqs, inst_freqs[:, -1:]), axis=1)
        else:
            inst_freqs = [pade6(row, dt) / (2.0 * np.pi) for row in inst_phase]
        if alpha is None:
            return np.array(inst_freqs)
        return np.array([filt6(row, alpha) for row in inst_freqs])  # Filter freqs

    def show(self):
        """Display all figures created so far."""
        plt.show()
if __name__ == "__main__":
    from PyEMD import EMD

    # Demo signal: a slowly accelerating sine minus a low-frequency cosine,
    # sampled every 0.01 over [0, 3).
    time_axis = np.arange(0, 3, 0.01)
    fast_component = np.sin(13 * time_axis + 0.2 * time_axis**1.4)
    slow_component = np.cos(3 * time_axis)
    signal = fast_component - slow_component

    # Decompose the signal and fetch the resulting IMFs and residue
    decomposer = EMD()
    decomposer.emd(signal)
    components, remainder = decomposer.get_imfs_and_residue()

    # The helper is bound to the EMD instance, though the data is also
    # passed explicitly to each plotting call below.
    plotter = Visualisation(decomposer)

    # First figure: every IMF plus the residue
    plotter.plot_imfs(
        imfs=components,
        residue=remainder,
        t=time_axis,
        include_residue=True,
    )

    # Second figure: instantaneous frequency of every IMF
    plotter.plot_instant_freq(time_axis, imfs=components)

    # Display both figures
    plotter.show()