import numpy as np
from warnings import warn
from .core import carthesian_sort
from .artist import sqplot, default_config
from .sampler import samplepoints
from .neighbormap import neighbormap
from .utils import dualgrid, dualgridflat, from_backend, to_backend, printmatrix
class SquareNet:
    """
    Grid straightening algorithm for D-dimensional point clouds.

    SquareNet reorganizes an unordered set of points into a structured grid
    by iteratively sorting indices along each spatial axis. The goal is to
    enforce a topology where neighboring points in the original space
    are also close on the grid.

    The grid is internally represented as an array of indices of shape
    ``(N1, ..., ND)``, where D is the dimensionality of the embedding space.

    Parameters
    ----------
    gridshape : tuple of int
        Shape of the target grid. The product of these dimensions must
        exactly match the total number of input points (bijective gridification).
    max_iter : int, default=1_000
        Maximum number of iterations allowed for the sorting procedure.
    warnings_ : bool, default=True
        If True, emits a ``ConvergenceWarning`` if the grid is not
        perfectly ordered at the end of ``fit``. Also forwarded to the
        backend conversion helper.
    verbose : int, default=2
        ``>= 2`` prints progress sections during ``fit``; ``>= 1`` prints the
        final status line. Also forwarded to the sorter and the plotter.
    backend : {"numpy", "torch", "jax"}, default="numpy"
        Input and output array backend.
    device : str, default="cpu"
        Device used when creating torch tensors and when converting inputs.

    Attributes
    ----------
    grid : array of shape ``gridshape``
        The structured grid of point indices.
    invert_grid : array
        Inverse mapping from point indices to grid coordinates.
    invgridflat : array
        Flattened version of the inverse mapping for fast indexing.
    points : array of shape (N, D) or None
        Copy of the point cloud given to the last ``fit``.
    pointsmaped : array of shape (*gridshape, D) or None
        ``points`` laid out on the grid (``self.map(self.points)``).
    learning_curve : list of float
        History of the disorder metric across iterations.
    fitted : bool
        Whether ``fit`` has completed.
    plot_config : dict-like
        Default rendering options used by ``plot``.

    Notes
    -----
    1. Run ``fit()`` with your preferred backend.
    2. Call ``SquareNet.map()`` on any feature indexed like the points
       ``(N, *C)`` to build a tensor version of it ``(*G, *C)`` based on the
       learned multi-indices.
    """

    def __init__(self, gridshape, max_iter=1_000, warnings_=True,
                 verbose=2, backend="numpy", device="cpu"):
        self.gridshape = tuple(gridshape)
        self.D = len(self.gridshape)
        self.N = int(np.prod(self.gridshape))
        self.max_iter = max_iter
        self.verbose = verbose
        self.backend = backend
        self.device = device
        if backend not in {"numpy", "jax", "torch"}:
            raise ValueError(f"Unknown backend '{backend}' Should be 'numpy' 'torch' or 'jax'")
        # Optional backends are imported lazily so numpy-only users do not
        # need jax/torch installed.
        if backend == "numpy":
            self.xp = np
        elif backend == "jax":
            import jax.numpy as jnp
            self.xp = jnp
        else:  # "torch"
            import torch
            self.xp = torch
        self.warnings_ = warnings_
        self.plot_config = default_config()
        self._reset_state()

    # ------------------------------------------------------------------
    # Internal utilities
    # ------------------------------------------------------------------

    def to_backend(self, x):
        """Convert ``x`` to this instance's backend and device (``utils.to_backend``)."""
        return to_backend(x, backend=self.backend, device=self.device, warnings_=self.warnings_)

    def _reset_state(self):
        """Reset the grid to the identity (row-major ``arange``) layout and clear fitted state."""
        # ``device`` is only forwarded when building torch tensors.
        extra = {"device": self.device} if self.backend == "torch" else {}
        self.grid = self.xp.arange(self.N, dtype=self.xp.int32, **extra).reshape(self.gridshape)
        self.invert_grid = dualgrid(self.grid, self.xp, self.N, self.gridshape, self.D)
        self.invgridflat = dualgridflat(self.grid, self.xp, self.N)
        self.points = None
        self.pointsmaped = None
        self.learning_curve = []
        self.fitted = False

    def _validate_points(self, points):
        """
        Validate (and, if given a method name, sample) the input point cloud.

        Parameters
        ----------
        points : array-like of shape (N, D) or str
            Point cloud, or a sampling method name forwarded to ``samplepoints``.

        Returns
        -------
        array of shape (N, D)
            The points converted to the instance backend.

        Raises
        ------
        AssertionError
            If the number of points or their dimension does not match the grid.
        """
        if isinstance(points, str):
            points = samplepoints(method=points, size=(self.N, self.D))
        points = self.to_backend(points)
        N, D = points.shape
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``; AssertionError is kept for backward compatibility.
        if N != self.N:
            raise AssertionError(
                f"Input points ({N}) must match grid size ({self.N}). "
                "For injective gridification, explicitly add fictive points with +- infinite coordinates, "
                "which will then be placed in the corners of the grid."
            )
        if D != self.D:
            raise AssertionError(
                f"Input dimension ({D}) must match D={self.D}. "
                "For manifold data, explicitly include singleton dimensions, e.g. gridshape = (100, 100, 1)."
            )
        return points

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------

    def fit(self, points, method="fast", fit_with_numpy=False):
        """
        Fit the grid to a point cloud.

        Parameters
        ----------
        points : array of shape (N, D) or str
            Input point cloud or sampling method name.
        method : {"fast", "robust", "ultimate"}, default="fast"
            ``robust`` can be up to 5 times slower, but will probably give a
            better grid. ``ultimate`` can be up to 30 times slower but gives
            the best results among the three methods.
        fit_with_numpy : bool, default=False
            Whether to convert to numpy (just for the fit) and fit with the
            numpy implementation. Might be faster; always consider it as an
            option. By default the fit runs on the backend (and device) given
            at init, e.g. numpy, torch or jax.

        Returns
        -------
        None
            The method updates the internal state of the grid in-place.
            See ``squarenet.core``.

        Raises
        ------
        AssertionError
            If ``points`` does not match the grid size or dimension.
        """
        verbose = self.verbose
        max_iter = self.max_iter

        def _log(msg):
            if verbose >= 2:
                print(msg)

        def _section(title):
            if verbose >= 2:
                print(f"\n=== {title} ===")

        old_backend = self.backend
        if fit_with_numpy:
            # Switch to numpy for the duration of the fit only; the
            # ``finally`` clause restores the original backend even if the
            # fit (or a warnings-as-errors ``warn``) raises.
            self.backend = "numpy"
        try:
            points = self._validate_points(points)
            if fit_with_numpy:
                # The sorter is called with backend="numpy": give it a numpy
                # grid as well, not only numpy points.
                self.grid = to_backend(self.grid, backend="numpy", device=self.device)
            self.points = points.clone() if hasattr(points, "clone") else points.copy()

            _log("Starting gridification... available method [fast, robust, ultimate]")
            _log(f"selected {method}")
            _section("Carthesian sort")
            if method == "fast":
                _log(f"(max iter: {max_iter})")
            elif method == "robust":
                _log(f"(max iter: 2 x {max_iter})")
            elif method == "ultimate":
                _log(f"(max iter: 2 x {max_iter} + 4 x {max_iter} + {max_iter})")

            self.grid, lc = carthesian_sort(self.grid, points, max_iter=max_iter, method=method,
                                            backend=self.backend, verbose=self.verbose)

            # --------------------------------------------------
            # Finalization
            # --------------------------------------------------
            # The sorter may return ``(curve, last_iter)``; in that case the
            # curve is truncated after ``last_iter``.
            if isinstance(lc, tuple):
                lc, last_iter = lc
                self.learning_curve = list(lc)[:last_iter + 1]
            else:
                self.learning_curve = list(lc)
            last_error = self.learning_curve[-1]
            last_iter = len(self.learning_curve) - 1
            converged = last_error == 0

            _section("Final Status")
            if verbose >= 1:
                if converged:
                    print(f"succesfully sorted at iteration {last_iter}")
                else:
                    print(f"not fully sorted (iter = {max_iter}, error={last_error/self.N:.4%})")
            # ``warnings_`` governs the warning independently of ``verbose``.
            if not converged and self.warnings_:
                warn(
                    "Sorting did not converge to zero. "
                    "Consider increasing `max_iter`.",
                    ConvergenceWarning,
                    stacklevel=2,
                )
        finally:
            if fit_with_numpy:
                self.backend = old_backend
                self.grid = to_backend(self.grid, backend=self.backend, device=self.device)
                if self.points is not None:
                    self.points = to_backend(self.points, backend=self.backend, device=self.device)

        # --------------------------------------------------
        # Update mappings
        # --------------------------------------------------
        self.invert_grid = dualgrid(self.grid, self.xp, self.N, self.gridshape, self.D)
        self.invgridflat = dualgridflat(self.grid, self.xp, self.N)
        self.fitted = True
        self.pointsmaped = self.map(self.points)

    # ------------------------------------------------------------------
    # Mapping
    # ------------------------------------------------------------------

    def map(self, features):
        """
        Map cloud data to grid structure.

        Parameters
        ----------
        features : array (numpy, torch or jax) of shape (N, *C)

        Returns
        -------
        array of shape (*gridshape, *C)
        """
        return features[self.grid]

    def invert_map(self, features):
        """
        Map grid data back to cloud ordering.

        Parameters
        ----------
        features : array (numpy, torch or jax) of shape (*gridshape, *C)

        Returns
        -------
        array of shape (N, *C)
        """
        flat = features.reshape(self.N, *features.shape[self.D:])
        return flat[self.invgridflat]

    def mapidx(self, index):
        """
        Cloud index -> grid index.

        Parameters
        ----------
        index : int or array of int
            Point index (or indices) in the original cloud ordering.

        Returns
        -------
        tuple of arrays
            Grid coordinates split per axis, directly usable for indexing.
        """
        coords = self.invert_grid[index]
        return tuple(coords.T)

    def invert_mapidx(self, index):
        """
        Grid index -> cloud index.

        Parameters
        ----------
        index : array-like of int
            A single multi-index of length D, or a 2D array holding one
            multi-index per row.

        Returns
        -------
        Point index (or indices) stored at those grid positions.
        """
        index = self.xp.asarray(index)
        if index.ndim == 1:
            return self.grid[tuple(index)]
        return self.grid[tuple(index.T)]

    def search_sorted(self, X, n_iter=2, side="left"):
        """
        Greedy grid search for a single point.

        X must be a SINGLE point; ``search_sorted`` returns a multi-index IX
        in the grid. IX is the grid index of PROBABLY one of the closest
        points to X in the first (``side='left'``) or last (``side='right'``)
        quadrant (relative to X) of the space.

        The search is a greedy coordinate descent: each dimension is refined
        by applying a 1D searchsorted along that grid axis, one axis after
        another. Increasing ``n_iter`` increases accuracy.

        Parameters
        ----------
        X : array-like of length D
            Query point.
        n_iter : int, default=2
            Number of full sweeps over the D axes.
        side : {"left", "right"}, default="left"
            Forwarded to ``searchsorted``; ``"left"`` additionally steps back
            by one index.

        Returns
        -------
        int32 array of shape (D,)
            Grid multi-index, clamped to the grid bounds.
        """
        # Build the multi-index in a Python list: jax arrays are immutable,
        # so item assignment into an ``xp.zeros`` buffer fails there.
        IX = [0] * self.D
        Y = self.pointsmaped
        side_offset = -1 if side == "left" else 0
        for _ in range(n_iter):
            for d in range(self.D):
                # Fix every grid axis but ``d`` at the current estimate and
                # take coordinate ``d`` of the points along that grid line.
                line_index = list(IX)
                line_index[d] = slice(None)
                line_index.append(d)
                pos = int(self.xp.searchsorted(Y[tuple(line_index)], X[d], side=side)) + side_offset
                IX[d] = min(max(pos, 0), self.gridshape[d] - 1)
        return self.xp.asarray(IX, dtype=self.xp.int32)

    # ------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------

    def plot(self,
             style="checkerboard",
             animate=False,
             save=True,
             save_path="sqrnet/plot",
             **kwargs
             ):
        """
        Display the mapped grid as a static figure or animation.

        Many rendering options (scales, colors, figsize, DS, frames …) can be
        passed as keyword arguments or configured once via
        ``self.plot_config[key] = value`` (``print(self.plot_config)`` lists
        all available parameters).

        Parameters
        ----------
        style : str
            Rendering style: ``"checkerboard"``, ``"mesh"`` or ``"scatter"``.
        animate : bool
            If True, produce a morphing animation (grid → identity).
            If False, render a static plot.
        save : bool, default=True
            If falsy, ``save_path`` is ignored (treated as None).
        save_path : str or None
            None → display with ``plt.show()``; str → save to disk.
        **kwargs
            Overrides for ``self.plot_config`` for this call only.

        Returns
        -------
        fig : matplotlib.figure.Figure
        ani : matplotlib.animation.FuncAnimation or None
            See ``squarenet.artist``.
        """
        gridpoints = self.pointsmaped
        if hasattr(gridpoints, "detach"):  # torch tensor -> host numpy array
            gridpoints = gridpoints.detach().cpu().numpy()
        gridpoints = np.ascontiguousarray(np.asarray(gridpoints))
        if not save:
            save_path = None
        plot_config = self.plot_config.copy()
        plot_config.update(kwargs)
        plot_config.update({"style": style, "animate": animate, "save_path": save_path})
        return sqplot(gridpoints, self.verbose,
                      style=style, animate=animate,
                      save=save, save_path=save_path,
                      cfg=plot_config)

    def neighbormap(self, max_sample_size=20_000_000, max_window_size=31, criterion="rank", thresholdcut=1,
                    projection=(0, 1), log2=False):
        """
        Compute and display a neighborhood map from gridded points.

        For each point X (or a subsample of ``max_sample_size`` pairs if the
        dataset is too large), scans a square window of radius
        ``wr = max_window_size // 2`` in the grid centered on X, and adds 1 to
        each cell if the point Y found in the cell matches the criterion.

        Parameters
        ----------
        max_sample_size : int, default=20_000_000
            Sample size for the number of pairs (X, Y).
        max_window_size : int, default=31
            Search window size, i.e. the size of the window in which the
            neighbor map is computed, such that
            ``|gridindex(X) - gridindex(Y)| <= wr``, where the comparison uses
            the L∞ norm on grid indices and ``wr = max_window_size // 2``.
        criterion : {"rank", "value"}, default="rank"
            Method used to define neighborhoods.
        thresholdcut : int or float, default=1
            Cutoff threshold for connections, interpreted according to
            ``criterion``:

            - ``"rank"``: ``thresholdcut = k`` means Y matches X if Y is
              among the k nearest points to X.
            - ``"value"``: ``thresholdcut`` is a distance threshold, and Y
              matches X if ``distance(X, Y) <= thresholdcut``.
        projection : tuple of int, default=(0, 1)
            Axes of the grid used for the 2D projection.
        log2 : bool, default=False
            If True, applies log2 scaling to counts.

        Warnings
        --------
        You get no information, and thus no guarantee at all, on what happens
        outside the search window. Conversely, a very large search window
        dilutes the information, since the probability of sampling pairs for
        a given grid offset decreases. Choosing the window size is therefore
        a trade-off between macroscopic and microscopic information.

        Returns
        -------
        None
            Prints the resulting neighborhood matrix (cells with a zero count
            are printed as -1). See ``squarenet.neighbormap``.
        """
        # ``neighbormap`` here is the module-level function imported from
        # ``.neighbormap``, not this method (class scope is not searched).
        nbmap = neighbormap(
            from_backend(self.pointsmaped),
            self.mapidx,
            sample_size=max_sample_size,
            windowradius=max_window_size // 2,
            criterion=criterion,
            thresholdcut=thresholdcut,
            projection=projection,
        )
        if log2:
            # Small offset keeps log2 finite; empty cells are overwritten below.
            arr = np.log2(nbmap + 0.01).astype(int)
            print(f"neighbor map on axes {projection} (log2 count)")
        else:
            arr = nbmap.astype(int)
            print(f"neighbor map on axes {projection}")
        arr[nbmap == 0] = -1  # mark empty cells (log2(0) is undefined)
        printmatrix(arr)
class ConvergenceWarning(UserWarning):
    """Warning emitted by ``SquareNet.fit`` when the sort does not converge to zero error."""