Source code for squarenet.utils

import numpy as np
from warnings import warn

def index_identity(shape):
    """Return an array holding, at every grid position, that position's index.

    The result has shape ``(*shape, len(shape))`` and satisfies
    ``index_identity(shape)[i, j, k] == [i, j, k]``.
    """
    stacked = np.indices(shape)  # shape (len(shape), *shape)
    # move the "which coordinate" axis from the front to the back
    return np.moveaxis(stacked, source=0, destination=-1)
def dualgrid(grid, xp, N, IJ, D):
    """Scatter every cell's multi-index into the row named by the grid value.

    Builds ``out`` of shape ``(N, D)`` such that ``out[grid[idx]] = idx`` for
    every multi-index ``idx`` of a grid of shape ``IJ``.

    Parameters
    ----------
    grid : integer array of shape ``IJ`` (presumably a permutation of
        ``range(N)`` — TODO confirm with callers).
    xp : array namespace module; ``torch`` and modules whose ``__name__``
        starts with ``"jax"`` get special handling, anything else is treated
        like numpy.
    N : number of output rows (expected to be ``prod(IJ)``).
    IJ : grid shape.
    D : number of coordinates per row (expected to be ``len(IJ)``).

    Notes
    -----
    On torch the output is allocated with ``empty``, so rows that ``grid``
    never points to are uninitialised; numpy/jax start from zeros.
    """
    backend = xp.__name__
    is_torch = backend == "torch"

    # torch must place new tensors on grid's device; numpy/jax take no device
    placement = {"device": grid.device} if is_torch else {}
    stack_axis = {"dim": -1} if is_torch else {"axis": -1}

    ranges = [xp.arange(size, dtype=grid.dtype, **placement) for size in IJ]
    coords = xp.stack(xp.meshgrid(*ranges, indexing="ij"), **stack_axis)
    coords = coords.reshape(N, D)
    targets = grid.reshape(-1)

    if is_torch:
        result = xp.empty((N, D), dtype=grid.dtype, **placement)
    else:
        result = xp.zeros((N, D), dtype=grid.dtype)

    if backend.startswith("jax"):
        # jax arrays are immutable: functional scatter
        return result.at[targets].set(coords)
    result[targets] = coords
    return result
def dualgridflat(grid, xp, N):
    """Flat counterpart of ``dualgrid``: ``out[grid.flat[k]] = k`` for ``k < N``.

    Parameters
    ----------
    grid : integer array with ``N`` elements (presumably a permutation of
        ``range(N)`` — TODO confirm with callers).
    xp : array namespace module; ``torch`` and modules whose ``__name__``
        starts with ``"jax"`` get special handling, anything else is treated
        like numpy.
    N : length of the output vector.

    Notes
    -----
    On torch the output is allocated with ``empty``, so positions that
    ``grid`` never points to are uninitialised; numpy/jax start from zeros.
    """
    targets = grid.reshape(-1)
    backend = xp.__name__

    if backend == "torch":
        spec = {"dtype": grid.dtype, "device": grid.device}
        result = xp.empty(N, **spec)
        result[targets] = xp.arange(N, **spec)
        return result

    ramp = xp.arange(N, dtype=grid.dtype)
    result = xp.zeros(N, dtype=grid.dtype)
    if backend.startswith("jax"):
        # jax arrays are immutable: functional scatter
        return result.at[targets].set(ramp)
    result[targets] = ramp
    return result
def breakpoint():
    """Stop execution on purpose by raising ``RuntimeError``.

    Debugging aid. Note that this module-level name shadows the builtin
    ``breakpoint`` for code that imports it.
    """
    message = "STOP checkpoint.\n Everything allright..."
    raise RuntimeError(message)
def project(gridpoints, feature_axes=(0, 1), index=0):
    """Slice a grid of points down to ``feature_axes`` and keep those coordinates.

    Parameters
    ----------
    gridpoints : array of shape ``(*grid_shape, D)``; the last axis holds the
        coordinates of each grid point.
    feature_axes : grid axes kept in the slice, in the desired output order.
        Negative values count from the last grid axis. The same values are
        used to pick the coordinate components from the last axis.
    index : position taken along every grid axis not in ``feature_axes``
        (presumably an int — an array here would trigger advanced indexing).

    Returns
    -------
    Array of shape ``(grid_shape[feature_axes[0]], ...,
    grid_shape[feature_axes[-1]], len(feature_axes))`` whose ``[..., k]``
    entry is coordinate component ``feature_axes[k]``.

    Raises
    ------
    IndexError if an entry of ``feature_axes`` is out of range.
    """
    grid_ndim = gridpoints.ndim - 1
    n_kept = len(feature_axes)

    selection = [index] * grid_ndim
    for ax in feature_axes:
        selection[ax] = slice(None)  # out-of-range axes raise IndexError here
    x = gridpoints[tuple(selection)]

    # Basic slicing keeps the sliced axes in increasing grid-axis order. To
    # reorder them as requested we need each requested axis' rank among the
    # kept axes; negative axes must be normalised first, otherwise (0, -1)
    # would be ranked as if -1 came before 0.
    positive_axes = [ax % grid_ndim for ax in feature_axes]
    rank = np.argsort(np.argsort(positive_axes, kind="stable"), kind="stable")
    x = x.transpose([int(r) for r in rank] + [n_kept])

    return x[..., list(feature_axes)]
def from_backend(x):
    """
    Return ``x`` as a NumPy array, whatever array library it comes from.

    NumPy arrays are returned unchanged (same object). Objects exposing
    ``detach`` (torch tensors) are detached from autograd and moved to the
    CPU before conversion. Everything else (jax arrays, lists, scalars, ...)
    goes through ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    if hasattr(x, "detach"):
        # torch: drop the autograd graph and copy to host memory
        return x.detach().cpu().numpy()
    return np.asarray(x)
def to_backend(x, backend="numpy", device="cpu", warnings_=True):
    """
    Convert an array-like to the target backend.

    Parameters
    ----------
    x : array-like (numpy array, torch tensor, jax array, nested list, ...).
    backend : one of ``"numpy"``, ``"torch"``, ``"jax"``.
    device : target device, passed to ``torch.device``; only used by the
        torch backend (ignored for numpy and jax).
    warnings_ : if True, warn when a tensor that lives on a non-CPU device
        is moved to the CPU.

    Returns
    -------
    ``np.ndarray``, ``torch.Tensor`` or jax array depending on ``backend``.

    Raises
    ------
    ValueError if ``backend`` is not recognised.
    """
    if backend == "numpy":
        return np.asarray(x)

    if backend == "torch":
        import torch

        device = torch.device(device)
        x_device = getattr(x, "device", None)
        # Only torch-style device objects carry ``.type``; other device
        # attributes are skipped. The warning is about moving data OFF an
        # accelerator onto the CPU, matching its message.
        if (
            warnings_
            and x_device is not None
            and hasattr(x_device, "type")
            and x_device.type != "cpu"
            and device.type == "cpu"
        ):
            warn(
                "Downgrading tensor as initial device was GPU but asked device is CPU"
            )
        return torch.as_tensor(x, device=device)

    if backend == "jax":
        import jax.numpy as jnp

        return jnp.asarray(x)

    raise ValueError(f"Unknown backend: {backend}")
def progress_bar(it, total, bar_length=30):
    """Print a single-line, in-place text progress bar.

    Parameters
    ----------
    it : current iteration, 0-based. The call with ``it >= total - 1`` is
        treated as the last one: it draws a full bar, prints ``total/total``
        and ends the line.
    total : total number of iterations. ``total <= 0`` is treated as already
        complete instead of dividing by zero.
    bar_length : number of cells in the bar; the bar is always exactly this
        wide.
    """
    done = total <= 0 or it >= total - 1
    progress = 1.0 if done else it / total
    # clamp so out-of-range ``it`` never changes the bar width
    filled = min(bar_length, max(0, int(bar_length * progress)))
    bar = "█" * filled + "-" * (bar_length - filled)
    if done:
        print(f"\r[{bar}] {total}/{total}")
    else:
        # no newline: the next call overwrites this line via "\r"
        print(f"\r[{bar}] {it}/{total}", end="")
def printmatrix(arr):
    """Pretty-print a 2-D integer array, blanking negative entries.

    Rows and columns are cropped symmetrically from both ends while every
    entry in them (and in their mirror row/column) is negative, so the crop
    keeps the original centre. The centre cell of the cropped matrix is
    drawn as ``■``. If everything is negative, a single ``■`` is printed
    (padded to the width of ``str(0.0)``).
    """
    col_max = arr.max(axis=0)
    row_max = arr.max(axis=1)
    # mirror-symmetric maxima => the kept span is symmetric around the centre
    cols = np.flatnonzero(np.maximum(col_max, col_max[::-1]) >= 0)
    rows = np.flatnonzero(np.maximum(row_max, row_max[::-1]) >= 0)

    if cols.size and rows.size:
        arr = arr[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
    else:
        arr = np.zeros((1, 1))

    width = max(len(str(value)) for value in arr.flatten())
    n_rows, n_cols = arr.shape
    centre = (n_rows // 2, n_cols // 2)
    marker = f"{'■':>{width}}"
    blank = " " * width

    for r, row in enumerate(arr):
        cells = [
            marker
            if (r, c) == centre
            else (f"{value:{width}d}" if value >= 0 else blank)
            for c, value in enumerate(row)
        ]
        print(" ".join(cells))
def show_search_result(left, right, true, points, sn):
    """Print and plot the outcome of a search against the true point.

    Parameters
    ----------
    left, right : indices returned by the search; mapped back to rows of
        ``points`` through ``sn.invert_mapidx``.
    true : row index of the true point in ``points``; it is printed after
        mapping through ``sn.mapidx``.
    points : array whose first two columns are used as x/y coordinates
        (presumably shape (n, 2) — TODO confirm).
    sn : object providing ``mapidx`` and ``invert_mapidx`` (project type).

    Draws all points in grey, the left/right hits as translucent blue/green
    disks and the true point as a red cross, then calls ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    print("true index", sn.mapidx(true))
    print("with search sorted:", left, right)

    hit_left = points[sn.invert_mapidx(left)]
    hit_right = points[sn.invert_mapidx(right)]

    plt.figure(figsize=(6, 6))
    plt.scatter(points[:, 0], points[:, 1], color="grey", s=10000 / len(points))

    halo = {"s": 200, "alpha": 0.5}
    plt.scatter(hit_left[0], hit_left[1], color="blue", label="found left", **halo)
    plt.scatter(hit_right[0], hit_right[1], color="green", label="found right", **halo)

    target = points[true]
    plt.scatter(target[0], target[1], s=200, marker="x", color="red", label="true")

    for mode in ("equal", "off"):
        plt.axis(mode)
    plt.legend(loc="upper right")
    plt.show()