import itertools as itt
from typing import Callable, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from ._util import expformat, truncate_colormap
__all__ = ["InfoPlotter"]
[docs]
class InfoPlotter:
line_formatter: Callable  # formats line names for tick labels and annotations
param_formatter: Callable  # formats parameter names for axis labels
math_mode: bool  # NOTE(review): declared but never assigned in the visible code

def __init__(self, line_formatter: Callable, param_formatter: Callable):
    """
    Store the label formatters and prepare the two default bar colors.

    Parameters
    ----------
    line_formatter : Callable
        Formatter for line names. Plotting methods wrap its output in
        ``$...$`` and call it with a ``transition`` keyword argument.
    param_formatter : Callable
        Formatter for physical parameter names, used in axis labels (its
        output is also wrapped in ``$...$``).
    """
    self.line_formatter = line_formatter
    self.param_formatter = param_formatter
    # Semi-transparent (alpha = 0.6) variants of matplotlib's "tab:blue" and
    # "tab:orange": the regular bar color and the alternative-group color.
    blue_rgba = to_rgba("tab:blue")
    orange_rgba = to_rgba("tab:orange")
    self.default_color = (*blue_rgba[:3], 0.6)
    self.alt_color = (*orange_rgba[:3], 0.6)
# Discrete plots
[docs]
def bar(
    self,
    lines: List[str],
    mis: List[float],
    errs: Optional[List[Union[float, Tuple[float, float]]]] = None,
    colors: Optional[List[str]] = None,
    sort: bool = False,
    nfirst: Optional[int] = None,
    transitions: bool = True,
    rotation: int = 90,
    fontsize: int = 20,
    capsize: int = 8,
    barwidth: float = 0.6,
    bottom_val: Optional[float] = None,
) -> Axes:
    """
    Draw a bar chart of mutual information values on the current axes.

    Parameters
    ----------
    lines : list
        Line (combination) identifiers, one per bar. Tick labels are built
        with ``self.lines_comb_formatter``.
    mis : list of float
        Mutual information of each bar, in bits.
    errs : list of float or (float, float), optional
        Error of each bar, aligned with ``mis``. A scalar ``e`` draws the
        interval ``[mi - e, mi + e]``; a pair ``(lo, hi)`` draws
        ``[mi - lo, mi + hi]`` (matplotlib's asymmetric ``yerr`` convention).
        Scalars and pairs may be mixed. ``None`` draws no error bars.
    colors : list of str, optional
        Face color of each bar, aligned with ``lines`` *before* sorting.
    sort : bool
        Sort the bars by decreasing mutual information.
    nfirst : int, optional
        Keep only the ``nfirst`` highest bars. Requires ``sort=True``.
    transitions : bool
        Forwarded as ``transition=`` to the line formatter.
    rotation : int
        Rotation of the x tick labels, in degrees.
    fontsize : int
        Font size of the tick labels and of the y-axis label.
    capsize : int
        Length of the error bar caps.
    barwidth : float
        Width of the bars.
    bottom_val : float, optional
        Lower y-limit. Defaults to the lowest bar (minus its error) minus 10%
        of the plotted range.

    Returns
    -------
    Axes
        The axes that were drawn on (``plt.gca()``).
    """
    assert len(mis) == len(lines)
    assert sort or (nfirst is None)
    # Validate the optional per-bar inputs against the original (unsorted,
    # untruncated) data so that `nfirst` does not break the length checks.
    if errs is not None:
        assert len(errs) == len(mis)
    if colors is not None:
        assert len(colors) == len(mis)

    ax = plt.gca()

    if sort:
        indices = np.array(mis).argsort()[::-1]
        if nfirst is not None:
            indices = indices[:nfirst]
        mis = [mis[i] for i in indices]
        lines = [lines[i] for i in indices]
        if errs is not None:
            errs = [errs[i] for i in indices]
        # Colors must follow exactly the same permutation/truncation as bars.
        if colors is not None:
            colors = [colors[i] for i in indices]

    positions = np.arange(len(mis))
    values = np.asarray(mis, dtype=float)

    # Normalize errors to (lower, upper) half-widths. matplotlib expects
    # asymmetric errors as a 2xN array, not a list of N pairs.
    if errs is None:
        err_low = err_high = np.zeros_like(values)
    else:
        pairs = np.array(
            [(e, e) if np.ndim(e) == 0 else (e[0], e[1]) for e in errs],
            dtype=float,
        ).reshape(-1, 2)
        err_low, err_high = pairs[:, 0], pairs[:, 1]

    barlist = ax.bar(
        positions,
        values,
        width=barwidth,
        color=self.default_color,
        edgecolor="black",
    )
    if errs is not None:
        ax.errorbar(
            positions,
            values,
            yerr=np.vstack([err_low, err_high]),
            fmt="none",
            capsize=capsize,
            color="tab:red",
        )
    if colors is not None:
        for patch, color in zip(barlist, colors):
            patch.set_facecolor(color)
    for patch in barlist:
        patch.set_linewidth(1.5)

    # Labels on multiples of 90 degrees stay centered; slanted labels are
    # right-anchored so that their end sits under the tick.
    ha = "center" if rotation % 90 == 0 else "right"
    rotation_mode = "default" if rotation % 90 == 0 else "anchor"
    ax.set_xticks(positions)
    ax.set_xticklabels(
        [
            "$" + self.lines_comb_formatter(l, transition=transitions) + "$"
            for l in lines
        ],
        rotation=rotation,
        fontsize=fontsize,
        ha=ha,
        rotation_mode=rotation_mode,
    )
    plt.yticks(fontsize=fontsize)

    low = np.nanmin(values - err_low)
    high = np.nanmax(values + err_high)
    margin = 0.1 * (high - low)
    ax.set_ylim(
        [
            low - margin if bottom_val is None else bottom_val,
            high + margin,
        ]
    )
    ax.set_ylabel("Mutual information (bits)", labelpad=24, fontsize=fontsize)
    return ax
[docs]
def matrix(
    self,
    lines: List[str],
    mis: List[List[float]],
    show_diag: bool = True,
    transitions: bool = True,
    cmap: str = "OrRd",
    fontsize: int = 10,
) -> Figure:
    """
    Draw a triangular heat map of pairwise mutual information on the current axes.

    Only entries whose column index is >= their row index (strictly > when
    ``show_diag`` is False) are shown; the other half is masked with NaN.
    The image uses ``origin="lower"``, so row ``i`` is drawn at height ``i``.

    Parameters
    ----------
    lines : list
        Line identifiers labelling both axes, formatted with
        ``self.line_formatter``.
    mis : list of list of float
        Square matrix of mutual information values (bits), indexed like
        ``lines`` along both dimensions.
    show_diag : bool
        Whether the diagonal is displayed.
    transitions : bool
        Forwarded as ``transition=`` to the line formatter.
    cmap : str
        Colormap name of the heat map.
    fontsize : int
        Font size of the tick labels.

    Returns
    -------
    Figure
        The figure owning the current axes.

    Raises
    ------
    ValueError
        If ``mis`` is not square or its size does not match ``len(lines)``.
    """
    ax = plt.gca()
    fig = ax.get_figure()

    mis = np.asarray(mis, dtype=float)
    if mis.ndim != 2 or mis.shape[0] != mis.shape[1] or mis.shape[0] != len(lines):
        raise ValueError(
            f"`mis` must be a square matrix of size len(lines)={len(lines)}, "
            f"got shape {mis.shape}."
        )

    # NaN out the lower triangle (including the diagonal unless show_diag).
    mask = np.where(
        np.tril(np.ones_like(mis), k=-1 if show_diag else 0), float("nan"), 1.0
    )
    im = ax.imshow(mask * mis, origin="lower", cmap=cmap)
    # Attach the colorbar explicitly to this axes instead of relying on the
    # figure's notion of "current" axes.
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Mutual information (bits)", labelpad=30, rotation=270)

    ticks = np.arange(mis.shape[0])
    tick_labels = [
        "$" + self.line_formatter(l, transition=transitions) + "$" for l in lines
    ]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    for set_labels in (ax.set_xticklabels, ax.set_yticklabels):
        set_labels(
            tick_labels,
            rotation=45,
            ha="right",
            rotation_mode="anchor",
            fontsize=fontsize,
        )
    return fig
[docs]
def bar_comparison(
    self,
    lines: List[str],
    mis: Dict[str, List[float]],
    errs: Optional[Dict[str, List[float]]],
    labels: Dict[str, str],
    transitions: bool = True,
    rotation: int = 90,
    bottom_val: Optional[float] = None,
    show_legend: bool = False,
    fontsize: int = 20,
    capsize: int = 8,
    barwidth: float = 0.6,
) -> Axes:
    """
    Compare several series of mutual information values on shared bar positions.

    The first series of ``mis`` is drawn as filled bars: ``self.alt_color`` for
    lines whose name contains ``"el"``, ``self.default_color`` otherwise, with
    red error bars. Every following series is overlaid as unfilled, hatched
    bars with black error bars.

    Parameters
    ----------
    lines : list
        Line (combination) identifiers, one per bar position. An entry may be
        a string or a sequence of names (joined with ``"_"`` before the
        ``"el"`` test). Tick labels use ``self.lines_comb_formatter``.
    mis : dict of str -> list of float
        One series of values (bits) per key, each aligned with ``lines``.
    errs : dict of str -> list of float, or None
        Symmetric errors keyed like ``mis``. ``None`` draws no error bars.
    labels : dict of str -> str
        Legend label of each series, keyed like ``mis``.
    transitions : bool
        Forwarded as ``transition=`` to the line formatter.
    rotation : int
        Rotation of the x tick labels, in degrees.
    bottom_val : float, optional
        Lower y-limit. Defaults to the data minimum minus 10% of the range.
    show_legend : bool
        Draw a legend: one split-color swatch for the first series, then one
        entry per overlaid series.
    fontsize : int
        Font size of tick labels, y-axis label and legend.
    capsize : int
        Length of the error bar caps.
    barwidth : float
        Width of the bars.

    Returns
    -------
    Axes
        The axes that were drawn on (``plt.gca()``).

    Raises
    ------
    ValueError
        If ``mis`` is empty.
    """
    keys = list(mis.keys())
    if not keys:
        raise ValueError("`mis` must contain at least one series.")

    ax = plt.gca()

    def _joined_name(line):
        # Plain strings are used as-is; "_".join on a str would interleave its
        # characters and make the "el" test meaningless.
        return line if isinstance(line, str) else "_".join(line)

    is_alt = ["el" in _joined_name(line) for line in lines]
    idx_default = [i for i, flag in enumerate(is_alt) if not flag]
    idx_alt = [i for i, flag in enumerate(is_alt) if flag]

    containers = []  # every BarContainer drawn, to set a common edge width
    overlays = []  # hatched containers of the 2nd, 3rd, ... series
    for rank, key in enumerate(keys):
        series = mis[key]
        positions = np.arange(len(series))
        if rank == 0:
            containers.append(
                ax.bar(
                    idx_default,
                    [series[i] for i in idx_default],
                    width=barwidth,
                    color=self.default_color,
                    edgecolor="black",
                )
            )
            containers.append(
                ax.bar(
                    idx_alt,
                    [series[i] for i in idx_alt],
                    width=barwidth,
                    color=self.alt_color,
                    edgecolor="black",
                )
            )
            err_style = {"color": "tab:red", "linewidth": 1.5}
        else:
            overlay = ax.bar(
                positions,
                series,
                width=barwidth,
                color="none",
                label=labels[key],
                edgecolor="black",
                hatch="/",
            )
            containers.append(overlay)
            overlays.append(overlay)
            err_style = {"color": "black"}
        if errs is not None:
            ax.errorbar(
                positions,
                series,
                yerr=errs[key],
                fmt="none",
                capsize=capsize,
                **err_style,
            )

    for container in containers:
        for patch in container:
            patch.set_linewidth(1.5)

    ha = "center" if rotation % 90 == 0 else "right"
    rotation_mode = "default" if rotation % 90 == 0 else "anchor"
    ax.set_xticks(np.arange(len(lines)))
    ax.set_xticklabels(
        [
            "$" + self.lines_comb_formatter(l, transition=transitions) + "$"
            for l in lines
        ],
        rotation=rotation,
        fontsize=fontsize,
        ha=ha,
        rotation_mode=rotation_mode,
    )
    plt.yticks(fontsize=fontsize)
    ax.set_ylabel("Mutual information (bits)", labelpad=24, fontsize=fontsize)

    if errs is None:
        low = min(np.nanmin(mis[key]) for key in keys)
        high = max(np.nanmax(mis[key]) for key in keys)
    else:
        low = min(
            np.nanmin(np.asarray(mis[key]) - np.asarray(errs[key])) for key in keys
        )
        high = max(
            np.nanmax(np.asarray(mis[key]) + np.asarray(errs[key])) for key in keys
        )
    margin = 0.1 * (high - low)
    ax.set_ylim(
        [
            low - margin if bottom_val is None else bottom_val,
            high + margin,
        ]
    )

    if show_legend:
        from matplotlib.collections import PatchCollection

        class MulticolorPatch:
            """Legend proxy: a swatch split into one rectangle per color."""

            def __init__(self, colors):
                self.colors = colors

        class MulticolorPatchHandler:
            """Legend handler drawing a MulticolorPatch as adjacent rectangles."""

            def legend_artist(self, legend, orig_handle, fontsize, handlebox):
                width, height = handlebox.width, handlebox.height
                n_colors = len(orig_handle.colors)
                patches = [
                    plt.Rectangle(
                        [
                            width / n_colors * k - handlebox.xdescent,
                            -handlebox.ydescent,
                        ],
                        width / n_colors,
                        height,
                        facecolor=color,
                        edgecolor="black",
                        linewidth=1.5,
                    )
                    for k, color in enumerate(orig_handle.colors)
                ]
                patch = PatchCollection(patches, match_original=True)
                handlebox.add_artist(patch)
                return patch

        # One handle per series, paired with its label by key (not by the
        # insertion order of `labels`).
        ax.legend(
            [MulticolorPatch([self.default_color, self.alt_color]), *overlays],
            [labels[key] for key in keys],
            fontsize=fontsize,
            loc="upper right",
            handler_map={MulticolorPatch: MulticolorPatchHandler()},
        )
    return ax
[docs]
def summary_1d(
    self,
    parameters: Tuple[str, ...],
    regimes: Dict[str, Dict[str, Tuple]],
    best_lines: List[Tuple[str, ...]],
    confidences: List[float],
) -> Figure:
    """
    Plot a one-row summary of the most informative lines per regime of one parameter.

    A new figure is created. The first parameter of ``regimes`` defines the
    x-axis. Each cell ``i`` (from ``x = i`` to ``x = i + 1``) shows the
    candidate line combinations of ``best_lines[i]``. Each candidate is
    printed with its probability ``p`` of being the best (clipped to
    [0.1%, 99.9%]). The cell is colored after the first candidate's
    probability.

    Parameters
    ----------
    parameters : tuple of str
        Physical parameters to estimate, e.g. ``('g0',)``. A single string is
        wrapped into a tuple; it is not otherwise used here.
    regimes : dict
        Bounds of the sub-regimes, keyed by parameter, e.g.
        ``{'av': {'1': [1, 2], '2': [2, None]}}``. Only the first parameter is
        used. Regimes whose lower bound is ``None`` are skipped; a ``None``
        upper bound adds a final ``+inf`` tick. Bounds of ``'g0'`` are
        formatted with ``expformat``.
    best_lines : list
        Per cell: a single line name (str), a sequence of candidate line
        combinations, or ``None`` for a cell without result (drawn hatched
        gray).
    confidences : list
        Per cell: the probability (in [0, 1]) of each candidate of
        ``best_lines`` being the best, aligned with it (a bare float when the
        cell holds a single string).

    Returns
    -------
    Figure
        The newly created figure.
    """
    ###
    xscale = 1.2
    yscale = 1.0
    dpi = 200
    ###
    fig, ax = plt.subplots(
        1, 1, figsize=(xscale * 6.4, 0.5 * yscale * 4.8), dpi=dpi
    )

    # Checking
    if isinstance(parameters, str):
        parameters = (parameters,)

    # Grid: a vertical separator at each regime lower bound.
    param_regime = list(regimes.keys())[0]
    x = []
    for bounds in regimes[param_regime].values():
        if bounds is None or bounds[0] is None:
            continue
        ax.axvline(len(x) + 1, color="black")
        if param_regime in ["g0"]:  # TODO
            x.append(f"${expformat(bounds[0])}$")
        else:
            x.append(f"${bounds[0]}$")
        if bounds[1] is None:
            x.append("$+\\infty$")

    # Text size shrinks with the number of candidates stacked in one cell.
    fontsizes = {1: 13, 2: 10, 3: 10, 4: 10}

    cmap = plt.get_cmap("gist_rainbow")
    subcmap = truncate_colormap(cmap, 0.0, 0.35)
    for i, (l, c) in enumerate(zip(best_lines, confidences), 1):
        if l is None:
            ax.add_patch(Rectangle((i, 0), 1, 1, color="gray", alpha=0.6))
            ax.add_patch(Rectangle((i, 0), 1, 1, fill=False, hatch="//"))
            continue

        if isinstance(l, str):
            l, c = (l,), (c,)
        l, c = list(l), list(c)
        ax.add_patch(Rectangle((i, 0), 1, 1, color=subcmap(c[0]), alpha=0.4))

        texts = []
        for combo, proba in zip(l, c):
            percent = 100 * proba
            if percent > 99.9:
                percent, sign = 99.9, ">"
            elif percent < 0.1:
                percent, sign = 0.1, "<"
            else:
                sign = "="
            # "\\%" yields the two characters "\%" (a mathtext percent sign)
            # without relying on an invalid Python escape sequence.
            texts.append(
                f"${self.lines_comb_formatter(combo, transition=True)}$"
                f"\n$p {sign} {percent:.1f}\\%$"
            )
        ax.text(
            i + 0.5,
            0.5,
            "\n\n".join(texts),
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsizes.get(len(l), min(fontsizes.values())),
        )

    # Settings
    ax.set_xticks(np.arange(1, len(x) + 1))
    ax.set_yticks([])
    ax.set_xticklabels(x)
    ax.set_xlabel("$" + self.param_formatter(param_regime) + "$", labelpad=10)
    ax.set_xlim([1, len(x)])
    ax.set_ylim([0, 1])
    return fig
[docs]
def summary_2d(
    self,
    parameters: Tuple[str, ...],
    regimes: Dict[str, Dict[str, Tuple]],
    best_lines: List[List[Tuple[str, ...]]],
    confidences: List[List[float]],
) -> Figure:
    """
    Plot a grid summary of the most informative lines per pair of regimes.

    A new figure is created. The first two parameters of ``regimes`` define
    the x- and y-axes. Cell ``(i, j)`` spans ``[i + 1, i + 2] x [j + 1, j + 2]``
    and shows the candidate line combinations of ``best_lines[i][j]``. Each
    candidate is printed with its probability ``p`` of being the best
    (clipped to [0.1%, 99.9%]). The cell is colored after the first
    candidate's probability.

    Parameters
    ----------
    parameters : tuple of str
        Physical parameters to estimate. A single string is wrapped into a
        tuple; it is not otherwise used here.
    regimes : dict
        Bounds of the sub-regimes, keyed by parameter (see ``summary_1d``).
        Regimes whose lower bound is ``None`` are skipped; a ``None`` upper
        bound adds a final ``+inf`` tick. y-axis bounds are formatted with
        ``expformat``, x-axis bounds are printed as-is.
    best_lines : list of list
        ``best_lines[i][j]``: a single line name (str), a sequence of up to 4
        candidate line combinations, or ``None`` for a cell without result
        (drawn hatched gray).
    confidences : list of list
        ``confidences[i][j]``: probabilities aligned with ``best_lines[i][j]``
        (a bare float when the cell holds a single string).

    Returns
    -------
    Figure
        The newly created figure.

    Raises
    ------
    ValueError
        If a cell holds more candidates than the layout supports (4).
    """
    ###
    xscale = 1.2
    yscale = 1.0
    dpi = 200
    ###
    fig, ax = plt.subplots(1, 1, figsize=(xscale * 6.4, yscale * 4.8), dpi=dpi)

    # Checking
    if isinstance(parameters, str):
        parameters = (parameters,)

    # Grid: separators at each regime lower bound, vertical for the first
    # parameter and horizontal for the second.
    # NOTE(review): only y labels go through expformat here, whereas
    # summary_1d exp-formats only 'g0' -- confirm this asymmetry is intended.
    param_regime_1, param_regime_2 = list(regimes.keys())[0:2]
    x, y = [], []
    for bounds in regimes[param_regime_1].values():
        if bounds is None or bounds[0] is None:
            continue
        ax.axvline(len(x) + 1, color="black")
        x.append(f"${bounds[0]}$")
        if bounds[1] is None:
            x.append("$+\\infty$")
    for bounds in regimes[param_regime_2].values():
        if bounds is None or bounds[0] is None:
            continue
        ax.axhline(len(y) + 1, color="black")
        y.append(f"${expformat(bounds[0])}$")
        if bounds[1] is None:
            y.append("$+\\infty$")

    # Relative text positions inside a cell for 1 to 4 candidates.
    coords = {
        1: [(0.5, 0.5)],
        2: [(0.5, 0.7), (0.5, 0.3)],
        3: [(0.5, 0.7), (0.25, 0.3), (0.75, 0.3)],
        4: [(0.25, 0.7), (0.75, 0.7), (0.25, 0.3), (0.75, 0.3)],
    }
    fontsizes = {1: 13, 2: 10, 3: 8, 4: 8}

    cmap = plt.get_cmap("gist_rainbow")
    subcmap = truncate_colormap(cmap, 0.0, 0.35)
    for i, row in enumerate(best_lines):
        for j, l in enumerate(row):
            c = confidences[i][j]
            x0, y0 = i + 1, j + 1
            if l is None:
                ax.add_patch(Rectangle((x0, y0), 1, 1, color="gray", alpha=0.6))
                ax.add_patch(Rectangle((x0, y0), 1, 1, fill=False, hatch="//"))
                continue

            if isinstance(l, str):
                l, c = (l,), (c,)
            if len(l) not in coords:
                raise ValueError(
                    f"Cells support 1 to {max(coords)} candidates, "
                    f"got {len(l)} at position ({i}, {j})."
                )
            ax.add_patch(
                Rectangle((x0, y0), 1, 1, color=subcmap(c[0]), alpha=0.4)
            )
            for k, combo in enumerate(l):
                percent = 100 * c[k]
                if percent > 99.9:
                    percent, sign = 99.9, ">"
                elif percent < 0.1:
                    percent, sign = 0.1, "<"
                else:
                    sign = "="
                dx, dy = coords[len(l)][k]
                # "\\%" yields "\%" (mathtext percent) without an invalid
                # Python escape sequence.
                ax.text(
                    x0 + dx,
                    y0 + dy,
                    f"${self.lines_comb_formatter(combo, transition=True)}$"
                    f"\n$p {sign} {percent:.1f}\\%$",
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontsize=fontsizes[len(l)],
                )

    # Settings
    ax.set_xticks(np.arange(1, len(x) + 1))
    ax.set_yticks(np.arange(1, len(y) + 1))
    ax.set_xticklabels(x)
    ax.set_yticklabels(y)
    ax.set_xlabel("$" + self.param_formatter(param_regime_1) + "$", labelpad=10)
    ax.set_ylabel("$" + self.param_formatter(param_regime_2) + "$", labelpad=10)
    ax.set_xlim([1, len(x)])
    ax.set_ylim([1, len(y)])
    return fig
# Continuous plots
[docs]
def profile(self):
    """Placeholder for a single-profile plot; always raises NotImplementedError."""
    message = "TODO"
    raise NotImplementedError(message)
[docs]
def profiles_summary(self):
    """Placeholder for a profiles summary plot; always raises NotImplementedError."""
    message = "TODO"
    raise NotImplementedError(message)
[docs]
def map(
    self,
    xticks: np.ndarray,
    yticks: np.ndarray,
    mat: np.ndarray,
    vmax: Optional[float] = None,
    cmap: str = "jet",
    paramx: Optional[str] = None,
    paramy: Optional[str] = None,
) -> Axes:
    """
    Draw a log-log pseudocolor map of an amount of information on the current axes.

    Parameters
    ----------
    xticks, yticks : np.ndarray
        1D coordinates along x and y. They are expanded with ``np.meshgrid``,
        so ``mat`` is laid out with rows along ``yticks`` and columns along
        ``xticks``.
    mat : np.ndarray
        Values to display, in bits.
    vmax : float, optional
        Upper bound of the color scale (the lower bound is fixed at 0).
        ``None`` lets matplotlib use the data maximum.
    cmap : str
        Colormap name.
    paramx, paramy : str, optional
        Parameter names for the x/y axis labels, formatted with
        ``self.param_formatter``. An axis is left unlabelled when ``None``.

    Returns
    -------
    Axes
        The axes that were drawn on (``plt.gca()``).
    """
    ax = plt.gca()
    X, Y = np.meshgrid(xticks, yticks)
    im = ax.pcolor(X, Y, mat, cmap=cmap, vmin=0, vmax=vmax)
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label("Amount of information (bits)", labelpad=10)
    ax.set_xscale("log")
    ax.set_yscale("log")
    # Do not hand the default `None` to the formatter.
    if paramx is not None:
        ax.set_xlabel(f"${self.param_formatter(paramx)}$")
    if paramy is not None:
        ax.set_ylabel(f"${self.param_formatter(paramy)}$")
    return ax
[docs]
def maps_summary(self):
    """Placeholder for a maps summary plot; always raises NotImplementedError."""
    message = "TODO"
    raise NotImplementedError(message)
# Helpers