"""Helper functions."""
from copy import deepcopy
from warnings import warn
import matplotlib.projections as proj
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.image import NonUniformImage
__author__ = "Deepansh J. Srivastava"
__email__ = "srivastava.89@osu.edu"
scalar = ["scalar", "vector_1", "pixel_1", "matrix_1_1", "symmetric_matrix_1"]
def _get_label_from_dv(dv, i):
"""Return label along with the unit of the dependent variable
Args:
dv: DependentVariable object.
i: integer counter.
"""
name, unit = dv.name, dv.unit
name = name if name != "" else str(i)
label = f"{name} / ({unit})" if unit != "" else name
return label
[docs]class CSDMAxes(plt.Axes):
"""A custom CSDM data plot axes."""
name = "csdm"
[docs] def plot(self, csdm, *args, **kwargs):
"""Generate a figure axes using the `plot` method from the matplotlib library.
Apply to all 1D datasets with single-component dependent-variables. For
multiple dependent variables, the data from individual dependent-variables is
plotted on the same figure.
Args:
csdm: A CSDM object of a one-dimensional dataset.
kwargs: Additional keyword arguments for the matplotlib plot() method.
Example
-------
>>> ax = plt.subplot(projection='csdm') # doctest: +SKIP
>>> ax.plot(csdm_object) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
if csdm.__class__.__name__ != "CSDM":
return super().plot(csdm, *args, **kwargs)
return self._call_1D(csdm, "plot", *args, **kwargs)
[docs] def scatter(self, csdm, *args, **kwargs):
"""Generate a figure axes using the `scatter` method from the matplotlib
library.
Apply to all 1D datasets with single-component dependent-variables. For
multiple dependent variables, the data from individual dependent-variables is
plotted on the same figure.
Args:
csdm: A CSDM object of a one-dimensional dataset.
kwargs: Additional keyword arguments for the matplotlib plot() method.
Example
-------
>>> ax = plt.subplot(projection='csdm') # doctest: +SKIP
>>> ax.scatter(csdm_object) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
if csdm.__class__.__name__ != "CSDM":
return super().scatter(csdm, *args, **kwargs)
return self._call_1D(csdm, "scatter", *args, **kwargs)
[docs] def imshow(self, csdm, origin="lower", *args, **kwargs):
"""Generate a figure axes using the `imshow` method from the matplotlib library.
Apply to all 2D datasets with either single-component (scalar),
three-components (pixel_3), or four-components (pixel_4) dependent-variables.
For single-component (scalar) dependent-variable, a colormap image is produced.
For three-components (pixel_3) dependent-variable, an RGB image is produced.
For four-components (pixel_4) dependent-variable, an RGBA image is produced.
For multiple dependent variables, the data from individual dependent-variables
is plotted on the same figure.
Args:
csdm: A CSDM object of a two-dimensional dataset with scalar, pixel_3, or
pixel_4 quantity_type dependent variable.
origin: The matplotlib `origin` argument. In matplotlib, the default is
'upper'. In csdmpy, however, the default to 'lower'.
kwargs: Additional keyword arguments for the matplotlib imshow() method.
Example
-------
>>> ax = plt.subplot(projection='csdm') # doctest: +SKIP
>>> ax.imshow(csdm_object) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
if csdm.__class__.__name__ != "CSDM":
return super().imshow(csdm, *args, **kwargs)
x = csdm.dimensions
if x[0].type == "linear" and x[1].type == "linear":
return self._call_uniform_2D_image(csdm, origin=origin, *args, **kwargs)
[docs] def contour(self, csdm, *args, **kwargs):
"""Generate a figure axes using the `contour` method from the matplotlib
library.
Apply to all 2D datasets with a single-component (scalar) dependent-variables.
For multiple dependent variables, the data from individual dependent-variables
is plotted on the same figure.
Args:
csdm: A CSDM object of a two-dimensional dataset with scalar dependent
variable.
kwargs: Additional keyword arguments for the matplotlib contour() method.
Example
-------
>>> ax = plt.subplot(projection='csdm') # doctest: +SKIP
>>> ax.contour(csdm_object) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
if csdm.__class__.__name__ != "CSDM":
return super().contour(csdm, *args, **kwargs)
x = csdm.dimensions
if x[0].type == "linear" and x[1].type == "linear":
return self._call_uniform_2D_contour(csdm, "contour", *args, **kwargs)
[docs] def contourf(self, csdm, *args, **kwargs):
"""Generate a figure axes using the `contourf` method from the matplotlib
library.
Apply to all 2D datasets with a single-component (scalar) dependent-variables.
For multiple dependent variables, the data from individual dependent-variables
is plotted on the same figure.
Args:
csdm: A CSDM object of a two-dimensional dataset with scalar dependent
variable.
kwargs: Additional keyword arguments for the matplotlib contourf() method.
Example
-------
>>> ax = plt.subplot(projection='csdm') # doctest: +SKIP
>>> ax.contourf(csdm_object) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
if csdm.__class__.__name__ != "CSDM":
return super().contour(csdm, *args, **kwargs)
x = csdm.dimensions
if x[0].type == "linear" and x[1].type == "linear":
return self._call_uniform_2D_contour(csdm, "contourf", *args, **kwargs)
def _call_1D(self, csdm, fn, *args, **kwargs):
_check_1D_dataset(csdm)
x = csdm.dimensions
z = csdm.split()
one = True if len(z) == 1 else False
legend = False
r_plt = None
for i, item in enumerate(z):
x_, y_ = item.to_list()
# dv will always be at index 0 because we called the object.split() before.
dv = item.dependent_variables[0]
kwargs_ = deepcopy(kwargs)
# add a default label if not provided by the user.
if "label" not in kwargs_.keys():
kwargs_["label"] = dv.name if one else _get_label_from_dv(dv, i)
if kwargs_["label"] != "":
legend = True
if fn == "plot":
r_plt = super().plot(x_, y_, *args, **kwargs_)
if fn == "scatter":
r_plt = super().scatter(x_, y_, *args, **kwargs_)
self.set_xlim(x[0].coordinates.value.min(), x[0].coordinates.value.max())
self.set_xlabel(x[0].axis_label)
ylabel = dv.axis_label[0] if one else "dimensionless"
self.set_ylabel(ylabel)
# self.grid(color="gray", linestyle="--", linewidth=0.5)
if legend:
self.legend()
if r_plt is None:
raise NotImplementedError("Cannot plot dataset")
return r_plt
def _call_uniform_2D_contour(self, csdm, fn, *args, **kwargs):
_check_2D_scalar_dataset(csdm)
kw_keys = kwargs.keys()
# set extent
x = csdm.dimensions
x0, x1 = x[0].coordinates.value, x[1].coordinates.value
# add cmap for multiple dependent variables.
cmaps_bool = True if "cmaps" in kw_keys else False
cmaps = kwargs.pop("cmaps") if cmaps_bool else None
one = True if len(csdm.dependent_variables) == 1 else False
r_plt = None
for i, dv in enumerate(csdm.dependent_variables):
y = dv.components
if dv.quantity_type in ["scalar", "vector_1", "pixel_1"]:
if cmaps_bool:
kwargs["cmap"] = cmaps[i]
if fn == "contour":
r_plt = super().contour(x0, x1, y[0], *args, **kwargs)
if fn == "contourf":
r_plt = super().contourf(x0, x1, y[0], *args, **kwargs)
self.set_xlim(x0.min(), x0.max())
self.set_ylim(x1.min(), x1.max())
self.set_xlabel(x[0].axis_label)
self.set_ylabel(x[1].axis_label)
if one:
self.set_title(dv.name)
if r_plt is None:
raise NotImplementedError("Cannot plot dataset")
return r_plt
def _call_uniform_2D_image(self, csdm, *args, **kwargs):
_check_2D_scalar_and_pixel_dataset(csdm)
kw_keys = kwargs.keys()
# set extent
x = csdm.dimensions
x0, x1 = x[0].coordinates.value, x[1].coordinates.value
extent = [x0[0], x0[-1], x1[0], x1[-1]]
if kwargs["origin"] == "upper":
extent = [x0[0], x0[-1], x1[-1], x1[0]]
if "extent" not in kw_keys:
kwargs["extent"] = extent
# add cmap for multiple dependent variables.
cmaps_bool = True if "cmaps" in kw_keys else False
cmaps = kwargs.pop("cmaps") if cmaps_bool else None
one = True if len(csdm.dependent_variables) == 1 else False
r_plt = None
for i, dv in enumerate(csdm.dependent_variables):
y = dv.components
if dv.quantity_type in ["scalar", "vector_1", "pixel_1"]:
if cmaps_bool:
kwargs["cmap"] = cmaps[i]
r_plt = super().imshow(y[0], *args, **kwargs)
if dv.quantity_type == "pixel_3":
r_plt = super().imshow(np.moveaxis(y.copy(), 0, -1), *args, **kwargs)
if dv.quantity_type == "pixel_4":
r_plt = super().imshow(np.moveaxis(y.copy(), 0, -1), *args, **kwargs)
self.set_xlabel(x[0].axis_label)
self.set_ylabel(x[1].axis_label)
if one:
self.set_title(dv.name)
if r_plt is None:
raise NotImplementedError("Cannot plot dataset")
return r_plt
try:
proj.register_projection(CSDMAxes)
except NameError:
pass
def _check_1D_dataset(csdm):
x, y = csdm.dimensions, csdm.dependent_variables
message = (
"The function requires a 1D dataset with single-component dependent "
"variables. For multiple dependent-variables, the data from all the "
"dependent variables are plotted on the same figure."
)
if len(x) != 1:
raise Exception(message)
for y_ in y:
if len(y_.components) != 1:
raise Exception(message)
def _check_2D_scalar_and_pixel_dataset(csdm):
x, y = csdm.dimensions, csdm.dependent_variables
message = (
"The function requires a 2D dataset with a single-component (scalar), "
"three components (pixel_3), or four components (pixel_4) dependent "
"variables. The pixel_3 produces an RGB image while pixel_4, a RGBA image."
)
if len(x) != 2:
raise Exception(message)
for y_ in y:
if len(y_.components) not in [1, 3, 4]:
raise Exception(message)
def _check_2D_scalar_dataset(csdm):
x, y = csdm.dimensions, csdm.dependent_variables
message = (
"The function requires a 2D dataset with a single-component (scalar), "
"dependent variables."
)
if len(x) != 2:
raise Exception(message)
for y_ in y:
if len(y_.components) != 1:
raise Exception(message)
# --------- cp plot functions ---------- #
def _preview(data, reverse_axis=None, range_=None, **kwargs):
"""Quick display of the data."""
if reverse_axis is not None:
kwargs["reverse_axis"] = reverse_axis
if range_ is None:
range_ = [[None, None], [None, None]]
x = data.dimensions
y = data.dependent_variables
y_len = len(y)
y_grid = int(y_len / 2) + 1
if len(x) == 0:
raise NotImplementedError(
"Preview of zero dimensional datasets is not implemented."
)
if len(x) > 2:
raise NotImplementedError(
"Preview of three or higher dimensional datasets " "is not implemented."
)
if np.any([x[i].type == "labeled" for i in range(len(x))]):
raise NotImplementedError("Preview of labeled dimensions is not implemented.")
fig = plt.gcf()
if y_len <= 2:
ax = fig.subplots(y_grid)
ax = [[ax]] if y_len == 1 else [ax]
else:
ax = fig.subplots(y_grid, 2)
if len(x) == 1:
one_d_plots(ax, x, y, range_, **kwargs)
if len(x) == 2:
two_d_plots(ax, x, y, range_, **kwargs)
return fig
def one_d_plots(ax, x, y, range_, **kwargs):
"""A collection of possible 1D plots."""
for i, y_item in enumerate(y):
i0 = int(i / 2)
j0 = int(i % 2)
ax_ = ax[i0][j0]
if y_item.quantity_type in scalar:
oneD_scalar(x, y_item, ax_, range_, **kwargs)
if "vector" in y_item.quantity_type:
vector_plot(x, y_item, ax_, range_, **kwargs)
# if "audio" in y_item.quantity_type:
# audio(x, y, i, fig, ax, **kwargs)
def two_d_plots(ax, x, y, range_, **kwargs):
"""A collection of possible 2D plots."""
for i, y_item in enumerate(y):
i0 = int(i / 2)
j0 = int(i % 2)
ax_ = ax[i0][j0]
if y_item.quantity_type == "pixel_3":
warn("This method interprets the `pixel_3` dataset as an RGB image.")
RGB_image(x, y_item, ax_, range_, **kwargs)
if y_item.quantity_type in scalar:
twoD_scalar(x, y_item, ax_, range_, **kwargs)
if "vector" in y_item.quantity_type:
vector_plot(x, y_item, ax_, range_, **kwargs)
def oneD_scalar(x, y, ax, range_, **kwargs):
reverse = [False]
if "reverse_axis" in kwargs.keys():
reverse = kwargs["reverse_axis"]
kwargs.pop("reverse_axis")
components = y.components.shape[0]
for k in range(components):
ax.plot(x[0].coordinates, y.components[k], **kwargs)
ax.set_xlim(x[0].coordinates.value.min(), x[0].coordinates.value.max())
ax.set_xlabel(f"{x[0].axis_label} - 0")
ax.set_ylabel(y.axis_label[0])
ax.set_title(f"{y.name}")
ax.grid(color="gray", linestyle="--", linewidth=0.5)
ax.set_xlim(range_[0])
ax.set_ylim(range_[1])
if reverse[0]:
ax.invert_xaxis()
def twoD_scalar(x, y, ax, range_, **kwargs):
reverse = [False, False]
if "reverse_axis" in kwargs.keys():
reverse = kwargs["reverse_axis"]
kwargs.pop("reverse_axis")
x0 = x[0].coordinates.value
x1 = x[1].coordinates.value
y00 = y.components[0]
extent = [x0[0], x0[-1], x1[0], x1[-1]]
if "extent" not in kwargs.keys():
kwargs["extent"] = extent
if x[0].type == "linear" and x[1].type == "linear":
if "origin" not in kwargs.keys():
kwargs["origin"] = "lower"
if "aspect" not in kwargs.keys():
kwargs["aspect"] = "auto"
cs = ax.imshow(y00, **kwargs)
else:
if "interpolation" not in kwargs.keys():
kwargs["interpolation"] = "nearest"
cs = NonUniformImage(ax, **kwargs)
cs.set_data(x0, x1, y00)
ax.add_artist(cs)
cbar = ax.figure.colorbar(cs, ax=ax)
cbar.ax.minorticks_off()
cbar.set_label(y.axis_label[0])
ax.set_xlim([extent[0], extent[1]])
ax.set_ylim([extent[2], extent[3]])
ax.set_xlabel(f"{x[0].axis_label} - 0")
ax.set_ylabel(f"{x[1].axis_label} - 1")
ax.set_title(f"{y.name}")
ax.grid(color="gray", linestyle="--", linewidth=0.5)
ax.set_xlim(range_[0])
ax.set_ylim(range_[1])
if reverse[0]:
ax.invert_xaxis()
if reverse[1]:
ax.invert_yaxis()
def vector_plot(x, y, ax, range_, **kwargs):
reverse = [False, False]
if "reverse_axis" in kwargs.keys():
reverse = kwargs["reverse_axis"]
kwargs.pop("reverse_axis")
x0 = x[0].coordinates.value
if len(x) == 2:
x1 = x[1].coordinates.value
else:
x1 = np.zeros(1)
x0, x1 = np.meshgrid(x0, x1)
u1 = y.components[0]
v1 = y.components[1]
if "pivot" not in kwargs.keys():
kwargs["pivot"] = "middle"
ax.quiver(x0, x1, u1, v1, **kwargs)
ax.set_xlabel(f"{x[0].axis_label} - 0")
ax.set_xlim(x[0].coordinates.value.min(), x[0].coordinates.value.max())
if len(x) == 2:
ax.set_ylim(x[1].coordinates.value.min(), x[1].coordinates.value.max())
ax.set_ylabel(f"{x[1].axis_label} - 1")
if reverse[1]:
ax.invert_yaxis()
else:
ax.set_ylim([-y.components.max(), y.components.max()])
ax.set_title(f"{y.name}")
ax.grid(color="gray", linestyle="--", linewidth=0.5)
ax.set_xlim(range_[0])
ax.set_ylim(range_[1])
if reverse[0]:
ax.invert_xaxis()
def RGB_image(x, y, ax, range_, **kwargs):
reverse = [False, False]
if "reverse_axis" in kwargs.keys():
reverse = kwargs["reverse_axis"]
kwargs.pop("reverse_axis")
y0 = y.components
ax.imshow(np.moveaxis(y0 / y0.max(), 0, -1), **kwargs)
ax.set_title(f"{y.name}")
ax.set_xlim(range_[0])
ax.set_ylim(range_[1])
if reverse[0]:
ax.invert_xaxis()
if reverse[1]:
ax.invert_yaxis()
# def audio(x, y, i0, fig, ax):
# try:
# SOUND = 1
# import sounddevice as sd
# except ImportError:
# SOUND = 0
# string = (
# "Module 'sounddevice' is not installed. All audio data files will "
# "not be played. To enable audio files, install 'sounddevice' using"
# " 'pip install sounddevice'."
# )
# warn(string)
# plot1D(x, y, i0, ax)
# if SOUND == 1:
# data_max = y[i0].components.max()
# sd.play(0.9 * y[i0].components.T / data_max, 1 / x[0].increment.to("s").value)