import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from pathlib import Path
from .utils import project
_STYLES = ["checkerboard", "mesh", "scatter"]
# Help text for every sqplot() argument and cfg key; handed to callers by
# default_config() under the "config descr" key.
_CONFIG_DESCR = """\
===============================================================
I. MAIN ARGUMENTS — passed directly to sqplot()
===============================================================
style : rendering style. One of:
'checkerboard' — points coloured in a 2-colour tile pattern.
'mesh' — grid lines drawn at 4 levels of density.
'scatter' — plain point cloud; depth-coloured (cmap) in
3-D, black in 2-D.
animate : False → single static PNG.
True → animation (MP4 if ffmpeg is available, GIF
otherwise) that morphs continuously from the input
grid to the identity grid and back.
save : whether to write the output to disk.
save_path : destination path (its suffix is replaced by the
output format; an existing file is never overwritten).
===============================================================
II. EXTRA ARGUMENTS — passed as cfg = {...}
===============================================================
All keys are optional. Only specify what you want to override;
everything else is filled from the defaults of default_config().
Layout
figsize : (width, height) of one panel; the 'checkerboard'
figure is widened by its number of panels.
Projection
projection : 3-tuple of axis indices, e.g. (0, 1, 2).
Used if ndim >= 3. This selects both
which axes to keep AND in what order.
Scale & density
scale_factor : positive integer. Bigger -> finer details.
pointsize : base size of a single point.
linewidth : base width of the mesh lines.
mesh_long_edge : length ratio (relative to the median) above which an edge is not plotted in 'mesh' style.
Colours
colors_checkerboard : 2 colours for the 'checkerboard' style.
cmap_scatter : colormap, e.g. 'plasma' or 'coolwarm', used in 'scatter' to encode the depth.
Animation
frames : number of frames in the animation.
interval : inter-frame delay in milliseconds (for interactive sessions).
fps : frames per second written (for the saved animation).
Display
show : if True, calls plt.show().
Has no effect in non-interactive sessions (matplotlib only warns).
\n
"""
# =========================================================
# UTILS
# =========================================================
[docs]
def default_config():
    """Return a new dict holding the default sqplot() rendering options.

    The dict is rebuilt on every call, so callers may update it freely.
    Note that the "supported styles" entry is the module-level ``_STYLES``
    list itself (not a copy).
    """
    # Metadata: list of styles and the human-readable key description.
    meta = {"supported styles": _STYLES, "config descr": _CONFIG_DESCR}
    # Figure size, projection axes and size/density knobs.
    geometry = dict(
        figsize=(4, 4),
        projection=(0, 1, 2),
        scale_factor=1,
        pointsize=1,
        linewidth=1,        # mesh style only
        mesh_long_edge=30,  # mesh style only
    )
    # Style-specific colours.
    palette = dict(
        colors_checkerboard=("peru", "mediumblue"),  # checkerboard style only
        cmap_scatter="plasma",                       # scatter style only (3-D depth)
    )
    # Animation timing.
    timing = dict(frames=60, interval=30, fps=20)
    return {**meta, **geometry, **palette, **timing, "show": True}
[docs]
def kill_long_edges(x, threshold = 1):
    """Break a polyline at over-long edges, modifying ``x`` in place.

    Parameters
    ----------
    x : np.ndarray
        Float array whose first axis runs along the line and whose last
        axis holds the point coordinates (e.g. shape (n, D)).
    threshold : float
        Threshold on the *squared* edge length.  For every edge
        ``x[k] -> x[k+1]`` whose squared length is >= threshold, the start
        point ``x[k]`` is set to NaN.  Edges involving NaN are never flagged.

    Returns
    -------
    None
    """
    steps = np.diff(x, axis=0)
    squared_lengths = np.sum(steps * steps, axis=-1)
    starts = x[:-1]  # basic slice -> view, so the assignment writes into x
    starts[squared_lengths >= threshold] = np.nan
[docs]
def atmost3D(gridpoints, projection=(0, 1, 2)):
    """Reduce a structured grid so that it can be plotted.

    Arrays with ``ndim <= 3`` (e.g. a 2-D grid of shape (n1, n2, D)) are
    returned unchanged (same object).  Anything with more dimensions is
    delegated to ``project(gridpoints, feature_axes=projection)`` from
    ``.utils``.

    NOTE(review): a 3-D grid (n1, n2, n3, D) has ndim == 4 and therefore also
    goes through ``project``; the exact reduction ``project`` performs is
    defined in .utils — confirm it leaves such grids plottable.
    """
    if gridpoints.ndim > 3:
        return project(gridpoints, feature_axes=projection)
    return gridpoints
def _safe_path(path: str, suffix: str) -> Path:
    """Return a non-existing output path with the given suffix.

    The suffix of ``path`` is replaced by ``suffix`` if it differs, and the
    parent directory is created if needed.  If the resulting file already
    exists, ``_1``, ``_2``, ... is appended to the stem until a free name is
    found, so an existing file is never returned.
    """
    target = Path(path)
    if target.suffix != suffix:
        target = target.with_suffix(suffix)
    folder = target.parent
    folder.mkdir(parents=True, exist_ok=True)
    if target.exists():
        base = target.stem
        n = 1
        while True:
            alternative = folder / f"{base}_{n}{suffix}"
            if not alternative.exists():
                return alternative
            n += 1
    return target
# =========================================================
# GRID PREPROCESSING
# =========================================================
def _prepare_grid(grid, DS=1):
    """Downsample, centre, rotate (3-D only) and normalise a structured grid.

    Parameters
    ----------
    grid : array_like
        Structured grid of shape (n1, n2, D) (2-D) or (n1, n2, n3, D) (3-D).
        Only the first 2 (2-D) or 3 (3-D) feature columns are kept.
    DS : int, default 1
        Minimum downsampling stride along every grid axis.  It is raised
        automatically for large grids; the thresholds are lower for "thin"
        grids (some grid axis has <= 10 points).

    Returns
    -------
    np.ndarray
        Floating-point array of shape (m1, m2, 2) or (m1, m2, m3, 3).  The
        points are centred on their NaN-ignoring mean, multiplied by a fixed
        3x3 matrix in 3-D, then shifted/scaled by a single global min and max
        (all coordinates share one scale and lie in ~[0, 1]).  Non-finite
        input values become NaN.
    """
    # np.array always copies; non-float grids (int, bool) are promoted to
    # float so the NaN assignment below is representable.
    g = np.array(grid)
    if not np.issubdtype(g.dtype, np.floating):
        g = g.astype(float)
    g[~np.isfinite(g)] = np.nan

    is3D = g.ndim == 4
    n1, n2 = g.shape[:2]
    n3 = g.shape[2] if is3D else 1
    npoints = n1 * n2 * n3

    # Stride k (starting at 2) is enforced once npoints exceeds the
    # corresponding threshold.
    if min(g.shape[:-1]) <= 10:  # thin case
        thresholds = (100_000, 300_000, 1_000_000, 3_000_000)
    else:  # volumic case
        thresholds = (500_000, 2_000_000, 5_000_000)
    for stride, limit in enumerate(thresholds, start=2):
        if npoints > limit:
            DS = max(DS, stride)

    if is3D:
        g = g[::DS, ::DS, ::DS, :3]
        P = np.array([
            [ 1, 1 / np.sqrt(3), -np.sqrt(2) / np.sqrt(3)],
            [-1, 1 / np.sqrt(3), -np.sqrt(2) / np.sqrt(3)],
            [ 0, 2 / np.sqrt(3),  np.sqrt(2) / np.sqrt(3)],
        ])
        g = (g - np.nanmean(g, axis=(0, 1, 2), keepdims=True)) @ P
    else:
        g = g[::DS, ::DS, :2]
        g = g - np.nanmean(g, axis=(0, 1), keepdims=True)
    # One global min/max keeps the aspect ratio; 1e-8 guards a degenerate grid.
    g_min, g_max = np.nanmin(g), np.nanmax(g)
    return (g - g_min) / (g_max - g_min + 1e-8)
# =========================================================
# SURFACE EXTRACTION
# =========================================================
def _get_surfaces(grid):
    """Split a prepared grid into drawable surfaces and background points.

    For a 3-D grid (ndim == 4, shape (i, j, k, D)) the surfaces are the three
    faces ``grid[0]``, ``grid[:, 0]`` and ``grid[:, :, -1]``, and the
    background is the flattened grid minus a central box: along each grid
    axis of length n the box spans ``[n//2 - d, n//2 + d)`` with
    ``d = max(n - n//2 - 10, 0)``, leaving a shell about 10 points thick
    (nothing is removed once any axis has <= 20 points).
    Otherwise the grid itself is the single surface and every point is a
    background point.

    Returns
    -------
    surfaces : list of np.ndarray
    points : np.ndarray of shape (N, D)
    """
    dim = grid.shape[-1]
    flat = grid.reshape(-1, dim)
    if grid.ndim != 4:
        return [grid], flat
    faces = [grid[0, :, :], grid[:, 0, :], grid[:, :, -1]]
    keep = np.ones(grid.shape[:-1], dtype=bool)
    interior = []
    for n in grid.shape[:-1]:
        centre = n // 2
        half = max(n - centre - 10, 0)
        interior.append(slice(centre - half, centre + half))
    keep[tuple(interior)] = False
    return faces, flat[keep.reshape(-1)]
# =========================================================
# CHECKERBOARD MASK
# =========================================================
def _checkerboard_mask(surface, stepi, stepj):
    """Flattened boolean tile pattern for a surface of shape (ni, nj, ...).

    Point (i, j) is True when the parity of ``i // stepi`` equals the parity
    of ``j // stepj``, i.e. tiles of stepi x stepj points alternate.  The
    result has length ni * nj in row-major (C) order.
    """
    rows, cols = surface.shape[:2]
    row_parity = (np.arange(rows) // stepi) % 2
    col_parity = (np.arange(cols) // stepj) % 2
    return np.equal.outer(row_parity, col_parity).ravel()
# =========================================================
# CORE — BUILD ANIMATION
# =========================================================
def _build_animation(grid, cfg):
    """Create the figure and the morphing animation for a structured grid.

    Both endpoints go through ``_prepare_grid``: the input grid (t=0) and a
    regular identity grid of the same shape (t=1).  Each frame draws the
    linear interpolation between them, with t going 0 -> 1 -> 0 over
    ``cfg["frames"]`` frames.  Light-grey background points (the whole grid
    in 2-D, its outer shell in 3-D) are drawn under every style.

    Layout depends on ``cfg["style"]``:

    - "checkerboard" : tile the surfaces with 2 colours; one panel per tile
      scale (3 panels, or a single panel when ``cfg["animate"]`` is True).
    - "mesh"         : grid lines at 4 densities on a single panel; edges much
      longer than the median (``cfg["mesh_long_edge"]``) are broken.
    - "scatter"      : every surface point; coloured by depth with
      ``cfg["cmap_scatter"]`` on one shared colour scale in 3-D, black in 2-D.

    Parameters
    ----------
    grid : np.ndarray
        Grid of shape (n1, n2, D) or (n1, n2, n3, D).
    cfg : dict
        Complete configuration (``default_config()`` merged with overrides)
        that also holds the "style" and "animate" keys.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ani : matplotlib.animation.FuncAnimation
    """
    sf = cfg["scale_factor"]
    style = cfg["style"]
    anim = cfg["animate"]
    assert style in ["mesh", "checkerboard", "scatter"], \
        f"Unknown plot style {style}, must be 'mesh', 'checkerboard' or 'scatter'"

    # One entry per drawn layer: tile/line density ("scale") and line width.
    if anim and (style == "checkerboard"):
        scales = [8]
        lws = [-1]
    elif style == "checkerboard":
        scales = [2, 4, 8]
        lws = [-1, -1, -1]
    elif style == "mesh":
        scales = [2, 4, 8, 16]
        lws = [w * cfg["linewidth"] for w in (1.5, 1, 0.7, 0.4)]
    else:  # scatter — scales/lws unused but keep variables defined
        scales = [1]
        lws = [-1]
    scales = [sf * scale for scale in scales]
    maxscales = max(scales)
    DS = 1
    frames = cfg["frames"]
    colors = cfg["colors_checkerboard"]
    figsize = cfg["figsize"]
    cmap = cfg["cmap_scatter"]

    # --- Precompute both endpoints ---
    coords = [np.linspace(0, 1, s) for s in grid.shape[:-1]]
    identity_raw = np.stack(np.meshgrid(*coords, indexing="ij"), axis=-1)
    g_prep = _prepare_grid(grid, DS=DS)
    id_prep = _prepare_grid(identity_raw, DS=DS)
    is_3d = g_prep.ndim == 4

    # --- Point size normalisation (by the number of surface points) ---
    n1, n2 = g_prep.shape[:2]
    n3 = g_prep.shape[2] if is_3d else 0
    npoints = n1 * n2 + (n1 * n3 + n2 * n3 if is_3d else 0)
    pt_size = cfg["pointsize"] * (80_000 / npoints)

    def interp(t):
        return (1 - t) * g_prep + t * id_prep

    # -------------------------------------------------------
    # Build artists at t=0/1
    # -------------------------------------------------------
    surfaces0, full0 = _get_surfaces(interp(0.0))
    surfaces1, full1 = _get_surfaces(interp(1.0))

    if style == "mesh":
        # Squared-length threshold for kill_long_edges: mesh_long_edge**2 times
        # the average (over grid axes) of the median squared edge length.
        long_edge = 0
        diff_axes = [0, 1, 2] if n3 > 1 else [0, 1]
        for axis in diff_axes:
            long_edge += np.nanmedian((np.diff(g_prep, axis=axis) ** 2).sum(axis=-1))
        long_edge *= ((cfg["mesh_long_edge"]) ** 2) / len(diff_axes)

    vmin = vmax = None
    if is_3d:
        # Draw background points in increasing order of the depth coordinate.
        order = np.argsort(full0[:, 2])
        full0, full1 = np.ascontiguousarray(full0[order]), np.ascontiguousarray(full1[order])
        if style == "scatter":
            # Flatten each surface and sort it by depth too (needs the 3rd
            # coordinate, hence only in 3-D).
            for i in range(len(surfaces0)):
                D = surfaces0[i].shape[-1]
                s0, s1 = surfaces0[i].reshape(-1, D), surfaces1[i].reshape(-1, D)
                order = np.argsort(s0[:, 2])
                surfaces0[i], surfaces1[i] = s0[order], s1[order]
            # One colour scale shared by all faces; without vmin/vmax each
            # face's collection would normalise to its own depth range.
            depths = np.concatenate([s[:, 2] for s in surfaces0])
            depths = depths[np.isfinite(depths)]
            if depths.size:
                vmin, vmax = depths.min(), depths.max()

    def _lerp(t):
        return ([(1 - t) * s0 + t * s1 for (s0, s1) in zip(surfaces0, surfaces1)],
                (1 - t) * full0 + t * full1)

    # -------------------------------------------------------
    # Figure setup
    # -------------------------------------------------------
    if style == "checkerboard":
        fig, axes = plt.subplots(
            1, len(scales),
            figsize=(figsize[0] * len(scales), figsize[1])
        )
        axes = [axes] if len(scales) == 1 else list(axes)
    else:
        # mesh: every scale is overlaid on the same axis; scatter: one layer.
        fig, ax_single = plt.subplots(1, 1, figsize=figsize)
        axes = [ax_single] * len(scales)
    for ax in axes:
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        ax.axis("off")
        ax.set_aspect("equal")

    # Light-grey background points, once per distinct axis.
    seen_axes = set()
    bg_artists = []
    for ax in axes:
        if id(ax) in seen_axes:
            continue
        seen_axes.add(id(ax))
        bg_artists.append(
            ax.scatter(full0[:, 0], full0[:, 1], c="lightgrey", s=pt_size, linewidths=0)
        )

    # Per layer, one record per surface:
    #   ("scatter", artist) | ("checker", mask, art_a, art_b)
    #   | ("mesh", ni, nj, stepi, stepj, lines_i, lines_j)
    scale_artists = []
    for ax, scale, lw in zip(axes, scales, lws):
        surface_artists = []
        for s in surfaces0:
            ni, nj = s.shape[:2]
            if style == "scatter":
                pts = s.reshape(-1, s.shape[-1])
                if is_3d:
                    art = ax.scatter(pts[:, 0], pts[:, 1], c=pts[:, 2], cmap=cmap,
                                     vmin=vmin, vmax=vmax, s=pt_size, linewidths=0)
                else:
                    # 2-D: no depth to encode, plain black points.
                    art = ax.scatter(pts[:, 0], pts[:, 1], c="black",
                                     s=pt_size, linewidths=0)
                surface_artists.append(("scatter", art))
            elif style == "checkerboard":
                stepi = int(max(ni // scale, 1))
                stepj = int(max(nj // scale, 1))
                mask = _checkerboard_mask(s, stepi, stepj)
                pts = s.reshape(-1, s.shape[-1])
                sc_a = ax.scatter(pts[mask, 0], pts[mask, 1],
                                  c=colors[0], s=pt_size, linewidths=0)
                sc_b = ax.scatter(pts[~mask, 0], pts[~mask, 1],
                                  c=colors[1], s=pt_size, linewidths=0)
                surface_artists.append(("checker", mask, sc_a, sc_b))
            elif style == "mesh":
                # Line spacing is a multiple of the finest spacing so that
                # coarser layers reuse lines of finer ones.
                basei = int(max(ni // maxscales, 1))
                stepi = basei * int(maxscales // scale)
                basej = int(max(nj // maxscales, 1))
                stepj = basej * int(maxscales // scale)
                # NaN-out over-long edges in place (s is a view of the t=0 data,
                # so the NaNs also propagate through _lerp).
                for i in range(0, ni, stepi):
                    kill_long_edges(s[i], threshold=long_edge)
                for j in range(0, nj, stepj):
                    kill_long_edges(s[:, j], threshold=long_edge)
                lines_i = [
                    ax.plot(s[i, :, 0], s[i, :, 1], color="black", lw=lw)[0]
                    for i in range(0, ni, stepi)
                ]
                lines_j = [
                    ax.plot(s[:, j, 0], s[:, j, 1], color="black", lw=lw)[0]
                    for j in range(0, nj, stepj)
                ]
                surface_artists.append(
                    ("mesh", ni, nj, stepi, stepj, lines_i, lines_j)
                )
        scale_artists.append(surface_artists)
    fig.tight_layout()

    def _iter_all_artists():
        yield from bg_artists
        for surf_list in scale_artists:
            for info in surf_list:
                if info[0] == "scatter":
                    yield info[1]
                elif info[0] == "checker":
                    yield info[2]
                    yield info[3]
                elif info[0] == "mesh":
                    yield from info[5]
                    yield from info[6]

    def _update(frame):
        # Ping-pong: t rises 0 -> 1 over the first half of the frames, then back.
        t = 2 * (frame / max(frames - 1, 1))
        if t > 1:
            t = 2 - t
        surfaces, full = _lerp(t)
        for bg in bg_artists:
            bg.set_offsets(full[:, :2])
        for surf_list in scale_artists:
            for s_idx, info in enumerate(surf_list):
                s = surfaces[s_idx]
                pts = s.reshape(-1, s.shape[-1])
                if info[0] == "scatter":
                    info[1].set_offsets(pts[:, :2])
                elif info[0] == "checker":
                    _, mask, sc_a, sc_b = info
                    if pts.shape[1] == 3:
                        order = np.argsort(pts[:, 2])
                        pts = pts[order]
                        mask_sorted = mask[order]
                    else:
                        mask_sorted = mask
                    sc_a.set_offsets(pts[mask_sorted, :2])
                    sc_b.set_offsets(pts[~mask_sorted, :2])
                elif info[0] == "mesh":
                    _, ni, nj, stepi, stepj, lines_i, lines_j = info
                    for k, i in enumerate(range(0, ni, stepi)):
                        lines_i[k].set_data(s[i, :, 0], s[i, :, 1])
                    for k, j in enumerate(range(0, nj, stepj)):
                        lines_j[k].set_data(s[:, j, 0], s[:, j, 1])
        return list(_iter_all_artists())

    ani = animation.FuncAnimation(
        fig, _update,
        frames=frames,
        interval=cfg["interval"],
        blit=True,
    )
    return fig, ani
# =========================================================
# PUBLIC API
# =========================================================
[docs]
def sqplot(grid, verbose=1, style="checkerboard", animate=False,
           save=True, save_path="sqrnet/plot", cfg=None):
    """
    Render a structured grid — static snapshot or morphing animation.

    Parameters
    ----------
    grid : np.ndarray
        Input structured grid of shape (..., D) where D is 2 (2-D) or 3 (3-D).
        Arrays with ``ndim > 3`` are first passed through ``atmost3D`` with
        ``cfg["projection"]``.
    verbose : int, default 1
        If >= 1, print the path the output is saved to.
    style : str
        One of ``"checkerboard"``, ``"mesh"``, or ``"scatter"``.
        ``"scatter"`` draws the grid points as a plain point cloud,
        coloured by depth (3rd coordinate after the isometric rotation)
        in 3-D.
    animate : bool
        False renders a single static PNG.  True builds an animation that
        morphs from the input grid to the identity grid and back, saved as
        MP4 when ffmpeg is available and as GIF otherwise.
    save : bool
        Whether to write the output to disk.
    save_path : str or None
        Destination path.  Its suffix is replaced by the output format and
        ``_1``, ``_2``, ... is appended if the file already exists.  ``None``
        also disables saving.
    cfg : dict, optional
        Rendering configuration. Defaults to ``default_config()``.
        Missing keys are filled in from the defaults, so you only need
        to specify what you want to override. Extra key for scatter:
        ``"cmap_scatter"`` (default ``"plasma"``).  If ``cfg["show"]`` is
        True, ``plt.show()`` is called: after saving for a static render,
        and only when nothing is saved for an animation.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ani : matplotlib.animation.FuncAnimation or None
        ``None`` when ``animate`` is False.
    """
    full_cfg = default_config()
    if cfg is not None:
        full_cfg.update(cfg)
    full_cfg.update({"style": style,
                     "animate": animate,
                     "save": save,
                     "save_path": save_path,
                     })
    grid = atmost3D(np.asarray(grid), projection=full_cfg["projection"])
    # `save=False` disables writing even when a path is given.
    out_path = full_cfg["save_path"] if full_cfg["save"] else None
    # Read from the merged config: the `cfg` argument may be None or partial.
    show = full_cfg["show"]

    if full_cfg["animate"]:
        fig, ani = _build_animation(grid, full_cfg)
        if out_path is not None:
            has_ffmpeg = animation.writers.is_available("ffmpeg")
            suffix = ".mp4" if has_ffmpeg else ".gif"
            writer = "ffmpeg" if has_ffmpeg else "pillow"
            out_path = _safe_path(out_path, suffix)
            if verbose >= 1:
                print(f"figure will be saved at {out_path}")
            N = int(np.prod(grid.shape[:-1]))
            if N >= 50_000 and suffix == ".gif":
                # Order of magnitude only; computed here so log10 never sees 0.
                rounded = 10 ** int(np.log10(N))
                if style == "scatter":
                    print(f"ffmpeg unavailable. Npoints = {rounded} and style is 'scatter' -> could take 1/2 minutes")
                else:
                    print(f"ffmpeg unavailable. Npoints = {rounded} -> could take 20/30 seconds")
            ani.save(out_path, writer=writer, fps=full_cfg["fps"])
        elif show:
            plt.show()
        return fig, ani

    static_cfg = full_cfg.copy()
    static_cfg["frames"] = 1
    fig, _ = _build_animation(grid, static_cfg)
    if out_path is not None:
        out_path = _safe_path(out_path, ".png")
        if verbose >= 1:
            print(f"figure will be saved at {out_path}")
        fig.savefig(out_path, bbox_inches="tight")
    if show:
        plt.show()
    return fig, None