pydrex.utils
PyDRex: Miscellaneous utility methods.
1"""> PyDRex: Miscellaneous utility methods.""" 2 3import os 4import platform 5import subprocess 6import sys 7from functools import wraps 8 9import dill 10import numba as nb 11import numpy as np 12from matplotlib.collections import PathCollection 13from matplotlib.legend_handler import HandlerLine2D, HandlerPathCollection 14from matplotlib.pyplot import Line2D 15from matplotlib.transforms import ScaledTranslation 16 17from pydrex import logger as _log 18 19 20def import_proc_pool() -> tuple: 21 """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`. 22 23 Import a process `Pool` object either from Ray of from Python's stdlib. 24 Both offer the same API, the Ray implementation will be preferred if available. 25 Using the `Pool` provided by Ray allows for distributed memory multiprocessing. 26 27 Returns a tuple containing the `Pool` object and a boolean flag which is `True` if 28 Ray is available. 29 30 """ 31 try: 32 from ray.util.multiprocessing import Pool 33 34 has_ray = True 35 except ImportError: 36 from multiprocessing import Pool 37 38 has_ray = False 39 return Pool, has_ray 40 41 42def in_ci(platform: str) -> bool: 43 """Check if we are in a GitHub runner with the given operating system.""" 44 # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables 45 return sys.platform == platform and os.getenv("CI") is not None 46 47 48class SerializedCallable: 49 """A serialized version of the callable f. 50 51 Serialization is performed using the dill library. The object is safe to pass into 52 `multiprocessing.Pool.map` and its alternatives. 53 54 .. note:: To serialize a lexical closure (i.e. a function defined inside a 55 function), use the `serializable` decorator. 
56 57 """ 58 59 def __init__(self, f): 60 self._f = dill.dumps(f, protocol=5, byref=True) 61 62 def __call__(self, *args, **kwargs): 63 return dill.loads(self._f)(*args, **kwargs) 64 65 66def serializable(f): 67 """Make decorated function serializable. 68 69 .. warning:: The decorated function cannot be a method, and it will loose its 70 docstring. It is not possible to use `functools.wraps` to mitigate this. 71 72 """ 73 return SerializedCallable(f) 74 75 76def defined_if(cond): 77 """Only define decorated function if `cond` is `True`.""" 78 79 def _defined_if(f): 80 def not_f(*args, **kwargs): 81 # Throw the same as we would get from `type(undefined_symbol)`. 82 raise NameError(f"name '{f.__name__}' is not defined") 83 84 @wraps(f) 85 def wrapper(*args, **kwargs): 86 if cond: 87 return f(*args, **kwargs) 88 return not_f(*args, **kwargs) 89 90 return wrapper 91 92 return _defined_if 93 94 95@nb.njit(fastmath=True) 96def strain_increment(dt, velocity_gradient): 97 """Calculate strain increment for a given time increment and velocity gradient. 98 99 Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the 100 “(engineering) shear strain” increment. 101 102 """ 103 return ( 104 np.abs(dt) 105 * np.abs( 106 np.linalg.eigvalsh((velocity_gradient + velocity_gradient.transpose()) / 2) 107 ).max() 108 ) 109 110 111@nb.njit 112def apply_gbs( 113 orientations, fractions, gbs_threshold, orientations_prev, n_grains 114) -> tuple[np.ndarray, np.ndarray]: 115 """Apply grain boundary sliding for small grains.""" 116 mask = fractions < (gbs_threshold / n_grains) 117 # _log.debug( 118 # "grain boundary sliding activity (volume percentage): %s", 119 # len(np.nonzero(mask)) / len(fractions), 120 # ) 121 # No rotation: carry over previous orientations. 
122 orientations[mask, :, :] = orientations_prev[mask, :, :] 123 fractions[mask] = gbs_threshold / n_grains 124 fractions /= fractions.sum() 125 # _log.debug( 126 # "grain volume fractions: median=%e, min=%e, max=%e, sum=%e", 127 # np.median(fractions), 128 # np.min(fractions), 129 # np.max(fractions), 130 # np.sum(fractions), 131 # ) 132 return orientations, fractions 133 134 135@nb.njit 136def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]: 137 """Extract deformation gradient, orientation matrices and grain sizes from y.""" 138 deformation_gradient = y[:9].reshape((3, 3)) 139 orientations = y[9 : n_grains * 9 + 9].reshape((n_grains, 3, 3)).clip(-1, 1) 140 fractions = y[n_grains * 9 + 9 : n_grains * 10 + 9].clip(0, None) 141 fractions /= fractions.sum() 142 return deformation_gradient, orientations, fractions 143 144 145def remove_nans(a): 146 """Remove NaN values from array.""" 147 a = np.asarray(a) 148 return a[~np.isnan(a)] 149 150 151def remove_dim(a, dim): 152 """Remove all values corresponding to dimension `dim` from an array. 153 154 Note that a `dim` of 0 refers to the “x” values. 155 156 Examples: 157 158 >>> a = [1, 2, 3] 159 >>> remove_dim(a, 0) 160 array([2, 3]) 161 >>> remove_dim(a, 1) 162 array([1, 3]) 163 >>> remove_dim(a, 2) 164 array([1, 2]) 165 166 >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 167 >>> remove_dim(a, 0) 168 array([[5, 6], 169 [8, 9]]) 170 >>> remove_dim(a, 1) 171 array([[1, 3], 172 [7, 9]]) 173 >>> remove_dim(a, 2) 174 array([[1, 2], 175 [4, 5]]) 176 177 """ 178 _a = np.asarray(a) 179 for i, _ in enumerate(_a.shape): 180 _a = np.delete(_a, [dim], axis=i) 181 return _a 182 183 184def add_dim(a, dim, val=0): 185 """Add entries of `val` corresponding to dimension `dim` to an array. 186 187 Note that a `dim` of 0 refers to the “x” values. 
188 189 Examples: 190 191 >>> a = [1, 2] 192 >>> add_dim(a, 0) 193 array([0, 1, 2]) 194 >>> add_dim(a, 1) 195 array([1, 0, 2]) 196 >>> add_dim(a, 2) 197 array([1, 2, 0]) 198 199 >>> add_dim([1.0, 2.0], 2) 200 array([1., 2., 0.]) 201 202 >>> a = [[1, 2], [3, 4]] 203 >>> add_dim(a, 0) 204 array([[0, 0, 0], 205 [0, 1, 2], 206 [0, 3, 4]]) 207 >>> add_dim(a, 1) 208 array([[1, 0, 2], 209 [0, 0, 0], 210 [3, 0, 4]]) 211 >>> add_dim(a, 2) 212 array([[1, 2, 0], 213 [3, 4, 0], 214 [0, 0, 0]]) 215 216 """ 217 _a = np.asarray(a) 218 for i, _ in enumerate(_a.shape): 219 _a = np.insert(_a, [dim], 0, axis=i) 220 return _a 221 222 223def default_ncpus() -> int: 224 """Get a safe default number of CPUs available for multiprocessing. 225 226 On Linux platforms that support it, the method `os.sched_getaffinity()` is used. 227 On Mac OS, the command `sysctl -n hw.ncpu` is used. 228 On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried. 229 If any of these fail, a fallback of 1 is used and a warning is logged. 230 231 """ 232 try: 233 match platform.system(): 234 case "Linux": 235 return len(os.sched_getaffinity(0)) - 1 # May raise AttributeError. 236 case "Darwin": 237 # May raise CalledProcessError. 238 out = subprocess.run( 239 ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True 240 ) 241 return int(out.stdout.strip()) - 1 242 case "Windows": 243 return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1 244 case _: 245 return 1 246 except (AttributeError, subprocess.CalledProcessError, KeyError): 247 return 1 248 249 250def diff_like(a): 251 """Get forward difference of 2D array `a`, with repeated last elements. 252 253 The repeated last elements ensure that output and input arrays have equal shape. 
254 255 Examples: 256 257 >>> diff_like(np.array([1, 2, 3, 4, 5])) 258 array([[1, 1, 1, 1, 1]]) 259 260 >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]])) 261 array([[1, 1, 1, 1, 1], 262 [2, 3, 3, 1, 1]]) 263 264 >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]])) 265 array([[ 1., 1., 1., 1., 1.], 266 [ 2., 3., 3., 1., 1.], 267 [-1., 0., 0., inf, nan]]) 268 269 """ 270 a2 = np.atleast_2d(a) 271 return np.diff( 272 a2, append=np.reshape(a2[:, -1] + (a2[:, -1] - a2[:, -2]), (a2.shape[0], 1)) 273 ) 274 275 276def angle_fse_simpleshear(strain): 277 """Get angle of FSE long axis anticlockwise from the X axis in simple shear.""" 278 return np.rad2deg(np.arctan(np.sqrt(strain**2 + 1) + strain)) 279 280 281def lag_2d_corner_flow(θ): 282 """Get predicted grain orientation lag for 2D corner flow. 283 284 See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222). 285 286 """ 287 _θ = np.ma.masked_less(θ, 1e-15) 288 return (_θ * (_θ**2 + np.cos(_θ) ** 2)) / ( 289 np.tan(_θ) * (_θ**2 + np.cos(_θ) ** 2 - _θ * np.sin(2 * _θ)) 290 ) 291 292 293@nb.njit(fastmath=True) 294def quat_product(q1, q2): 295 """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.""" 296 return [ 297 *q1[-1] * q2[:3] + q2[-1] * q1[:3] + np.cross(q1[:3], q1[:3]), 298 q1[-1] * q2[-1] - np.dot(q1[:3], q2[:3]), 299 ] 300 301 302def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs): 303 """Redraw legend on matplotlib axis or figure. 304 305 Transparency is removed from legend symbols. 306 If `fig` is not None and `remove_all` is True, 307 all legends are first removed from the parent figure. 308 Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default, 309 or `matplotlib.figure.Figure.legend` if `fig` is not None. 310 311 If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes 312 instead of taking up space in the original axes. 
This option requires `fig=None`. 313 314 .. warning:: 315 Note that if `fig` is not `None`, the legend may be cropped from the saved 316 figure due to a Matplotlib bug. In this case, it is required to add the 317 arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`, 318 where `legend` is the object returned by this function. To prevent the legend 319 from consuming axes/subplot space, it is further required to add the lines: 320 `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_layout(True)` 321 and `fig.set_layout_engine("none")` before saving the figure. 322 323 """ 324 handler_map = { 325 PathCollection: HandlerPathCollection( 326 update_func=_remove_legend_symbol_transparency 327 ), 328 Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency), 329 } 330 if fig is None: 331 legend = ax.get_legend() 332 if legend is not None: 333 handles, labels = ax.get_legend_handles_labels() 334 legend.remove() 335 if legendax is not None: 336 legendax.axis("off") 337 return legendax.legend(handles, labels, handler_map=handler_map, **kwargs) 338 return ax.legend(handler_map=handler_map, **kwargs) 339 else: 340 if legendax is not None: 341 _log.warning("ignoring `legendax` argument which requires `fig=None`") 342 for legend in fig.legends: 343 if legend is not None: 344 legend.remove() 345 if remove_all: 346 for ax in fig.axes: 347 legend = ax.get_legend() 348 if legend is not None: 349 legend.remove() 350 return fig.legend(handler_map=handler_map, **kwargs) 351 352 353def add_subplot_labels( 354 mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs 355): 356 """Add subplot labels to axes mosaic. 357 358 Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels. 359 If `labelmap` is None, the keys in `axs` will be used as the labels by default. 360 361 If `internal` is `False` (default), the axes titles will be used. 
362 Otherwise, internal labels will be drawn with `ax.text`, 363 in which case `loc` must be a tuple of floats. 364 365 Any axes in `axs` corresponding to the special key `legend` are skipped. 366 367 """ 368 for txt, ax in mosaic.items(): 369 if txt.lower() == "legend": 370 continue 371 _txt = labelmap[txt] if labelmap is not None else txt 372 if internal: 373 trans = ScaledTranslation(10 / 72, -5 / 72, ax.figure.dpi_scale_trans) 374 if isinstance(loc, str): 375 raise ValueError( 376 "'loc' argument must be a sequence of float when 'internal' is 'True'" 377 ) 378 ax.text( 379 *loc, 380 _txt, 381 transform=ax.transAxes + trans, 382 fontsize=fontsize, 383 bbox={ 384 "facecolor": (1.0, 1.0, 1.0, 0.3), 385 "edgecolor": "none", 386 "pad": 3.0, 387 }, 388 ) 389 else: 390 ax.set_title(_txt, loc=loc, fontsize=fontsize, **kwargs) 391 392 393def _remove_legend_symbol_transparency(handle, orig): 394 """Remove transparency from symbols used in a Matplotlib legend.""" 395 # https://stackoverflow.com/a/59629242/12519962 396 handle.update_from(orig) 397 handle.set_alpha(1)
def import_proc_pool() -> tuple:
    """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.

    Import a process `Pool` object either from Ray or from Python's stdlib.
    Both offer the same API, the Ray implementation will be preferred if available.
    Using the `Pool` provided by Ray allows for distributed memory multiprocessing.

    Returns a tuple containing the `Pool` object and a boolean flag which is `True` if
    Ray is available.

    """
    try:
        from ray.util.multiprocessing import Pool

        has_ray = True
    except ImportError:  # Ray is optional; fall back to the stdlib process pool.
        from multiprocessing import Pool

        has_ray = False
    return Pool, has_ray
Import either ray.util.multiprocessing.Pool
or multiprocessing.Pool
.
Import a process Pool
object either from Ray or from Python's stdlib.
Both offer the same API, the Ray implementation will be preferred if available.
Using the Pool
provided by Ray allows for distributed memory multiprocessing.
Returns a tuple containing the Pool
object and a boolean flag which is True
if
Ray is available.
def in_ci(platform: str) -> bool:
    """Check if we are in a GitHub runner with the given operating system."""
    # GitHub-hosted runners always set the CI environment variable:
    # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
    running_in_ci = os.getenv("CI") is not None
    return running_in_ci and sys.platform == platform
Check if we are in a GitHub runner with the given operating system.
class SerializedCallable:
    """Wrapper that stores the callable `f` in dill-serialized form.

    The wrapped callable is pickled with the dill library at construction time and
    re-hydrated on every call, which makes the wrapper safe to pass into
    `multiprocessing.Pool.map` and its alternatives.

    .. note:: To serialize a lexical closure (i.e. a function defined inside a
        function), use the `serializable` decorator.

    """

    def __init__(self, f):
        # Pickle eagerly with protocol 5; byref=True presumably keeps the payload
        # small by pickling module-level references by name — confirm in dill docs.
        self._f = dill.dumps(f, protocol=5, byref=True)

    def __call__(self, *args, **kwargs):
        # Re-hydrate on demand, then delegate the call.
        func = dill.loads(self._f)
        return func(*args, **kwargs)
A serialized version of the callable f.
Serialization is performed using the dill library. The object is safe to pass into
multiprocessing.Pool.map
and its alternatives.
To serialize a lexical closure (i.e. a function defined inside a
function), use the serializable
decorator.
def serializable(f):
    """Make decorated function serializable.

    .. warning:: The decorated function cannot be a method, and it will lose its
    docstring. It is not possible to use `functools.wraps` to mitigate this.

    """
    # Dill-based wrapper also handles lexical closures, which the stdlib
    # pickle module cannot serialize.
    return SerializedCallable(f)
Make decorated function serializable.
The decorated function cannot be a method, and it will lose its
docstring. It is not possible to use functools.wraps
to mitigate this.
def defined_if(cond):
    """Only define decorated function if `cond` is `True`."""

    def decorator(func):
        @wraps(func)
        def guarded(*args, **kwargs):
            if not cond:
                # Mimic the error raised by referencing an undefined name,
                # i.e. what `type(undefined_symbol)` would throw.
                raise NameError(f"name '{func.__name__}' is not defined")
            return func(*args, **kwargs)

        return guarded

    return decorator
Only define decorated function if cond
is True
.
@nb.njit(fastmath=True)
def strain_increment(dt, velocity_gradient):
    """Calculate strain increment for a given time increment and velocity gradient.

    Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the
    “(engineering) shear strain” increment.

    """
    # The symmetric part of the velocity gradient is the strain-rate tensor.
    strain_rate = (velocity_gradient + velocity_gradient.transpose()) / 2
    # The largest absolute eigenvalue sets the magnitude of the increment.
    return np.abs(dt) * np.max(np.abs(np.linalg.eigvalsh(strain_rate)))
Calculate strain increment for a given time increment and velocity gradient.
Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the “(engineering) shear strain” increment.
@nb.njit
def apply_gbs(
    orientations, fractions, gbs_threshold, orientations_prev, n_grains
) -> tuple[np.ndarray, np.ndarray]:
    """Apply grain boundary sliding for small grains.

    Grains whose volume fraction is below `gbs_threshold / n_grains` are treated as
    sliding rather than rotating: their orientations are reset to the corresponding
    entries of `orientations_prev` and their fractions are set to the threshold
    value, after which all fractions are renormalised to sum to one.

    Both `orientations` and `fractions` are modified in place and also returned.

    """
    # Grains below this volume fraction are considered small enough to slide.
    mask = fractions < (gbs_threshold / n_grains)
    # NOTE(review): debug logging left disabled — presumably because the logger
    # cannot be called from numba-jitted code; confirm before re-enabling.
    # _log.debug(
    #     "grain boundary sliding activity (volume percentage): %s",
    #     len(np.nonzero(mask)) / len(fractions),
    # )
    # No rotation: carry over previous orientations.
    orientations[mask, :, :] = orientations_prev[mask, :, :]
    # Pin sliding grains at the threshold size so they are not further reduced.
    fractions[mask] = gbs_threshold / n_grains
    # Renormalise so the volume fractions again sum to one.
    fractions /= fractions.sum()
    # _log.debug(
    #     "grain volume fractions: median=%e, min=%e, max=%e, sum=%e",
    #     np.median(fractions),
    #     np.min(fractions),
    #     np.max(fractions),
    #     np.sum(fractions),
    # )
    return orientations, fractions
Apply grain boundary sliding for small grains.
@nb.njit
def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Extract deformation gradient, orientation matrices and grain sizes from y.

    The flat state vector `y` is laid out as: 9 deformation gradient components,
    then 9 components per grain for the orientation matrices, then one volume
    fraction per grain.

    """
    n_rot = 9 * n_grains
    deformation_gradient = y[:9].reshape((3, 3))
    # Clip rotation matrix entries to [-1, 1] to suppress integration noise.
    orientations = y[9 : 9 + n_rot].reshape((n_grains, 3, 3)).clip(-1, 1)
    # Volume fractions cannot be negative; renormalise to sum to one.
    fractions = y[9 + n_rot : 9 + n_rot + n_grains].clip(0, None)
    fractions /= fractions.sum()
    return deformation_gradient, orientations, fractions
Extract deformation gradient, orientation matrices and grain sizes from y.
def remove_nans(a):
    """Remove NaN values from array."""
    arr = np.asarray(a)
    keep = np.logical_not(np.isnan(arr))
    return arr[keep]
Remove NaN values from array.
def remove_dim(a, dim):
    """Remove all values corresponding to dimension `dim` from an array.

    Note that a `dim` of 0 refers to the “x” values.

    Examples:

    >>> a = [1, 2, 3]
    >>> remove_dim(a, 0)
    array([2, 3])
    >>> remove_dim(a, 1)
    array([1, 3])
    >>> remove_dim(a, 2)
    array([1, 2])

    >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    >>> remove_dim(a, 0)
    array([[5, 6],
           [8, 9]])
    >>> remove_dim(a, 1)
    array([[1, 3],
           [7, 9]])
    >>> remove_dim(a, 2)
    array([[1, 2],
           [4, 5]])

    """
    arr = np.asarray(a)
    # Drop index `dim` along every axis in turn (e.g. both row and column in 2D).
    for axis in range(arr.ndim):
        arr = np.delete(arr, dim, axis=axis)
    return arr
Remove all values corresponding to dimension dim
from an array.
Note that a dim
of 0 refers to the “x” values.
Examples:
>>> a = [1, 2, 3]
>>> remove_dim(a, 0)
array([2, 3])
>>> remove_dim(a, 1)
array([1, 3])
>>> remove_dim(a, 2)
array([1, 2])
>>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
>>> remove_dim(a, 0)
array([[5, 6],
[8, 9]])
>>> remove_dim(a, 1)
array([[1, 3],
[7, 9]])
>>> remove_dim(a, 2)
array([[1, 2],
[4, 5]])
def add_dim(a, dim, val=0):
    """Add entries of `val` corresponding to dimension `dim` to an array.

    Note that a `dim` of 0 refers to the “x” values.

    Examples:

    >>> a = [1, 2]
    >>> add_dim(a, 0)
    array([0, 1, 2])
    >>> add_dim(a, 1)
    array([1, 0, 2])
    >>> add_dim(a, 2)
    array([1, 2, 0])

    >>> add_dim([1.0, 2.0], 2)
    array([1., 2., 0.])

    >>> add_dim([1, 2], 0, val=9)
    array([9, 1, 2])

    >>> a = [[1, 2], [3, 4]]
    >>> add_dim(a, 0)
    array([[0, 0, 0],
           [0, 1, 2],
           [0, 3, 4]])
    >>> add_dim(a, 1)
    array([[1, 0, 2],
           [0, 0, 0],
           [3, 0, 4]])
    >>> add_dim(a, 2)
    array([[1, 2, 0],
           [3, 4, 0],
           [0, 0, 0]])

    """
    _a = np.asarray(a)
    # Insert `val` at index `dim` along every axis. BUG FIX: the `val` parameter
    # was previously accepted but ignored (a literal 0 was always inserted).
    for i, _ in enumerate(_a.shape):
        _a = np.insert(_a, [dim], val, axis=i)
    return _a
Add entries of val
corresponding to dimension dim
to an array.
Note that a dim
of 0 refers to the “x” values.
Examples:
>>> a = [1, 2]
>>> add_dim(a, 0)
array([0, 1, 2])
>>> add_dim(a, 1)
array([1, 0, 2])
>>> add_dim(a, 2)
array([1, 2, 0])
>>> add_dim([1.0, 2.0], 2)
array([1., 2., 0.])
>>> a = [[1, 2], [3, 4]]
>>> add_dim(a, 0)
array([[0, 0, 0],
[0, 1, 2],
[0, 3, 4]])
>>> add_dim(a, 1)
array([[1, 0, 2],
[0, 0, 0],
[3, 0, 4]])
>>> add_dim(a, 2)
array([[1, 2, 0],
[3, 4, 0],
[0, 0, 0]])
224def default_ncpus() -> int: 225 """Get a safe default number of CPUs available for multiprocessing. 226 227 On Linux platforms that support it, the method `os.sched_getaffinity()` is used. 228 On Mac OS, the command `sysctl -n hw.ncpu` is used. 229 On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried. 230 If any of these fail, a fallback of 1 is used and a warning is logged. 231 232 """ 233 try: 234 match platform.system(): 235 case "Linux": 236 return len(os.sched_getaffinity(0)) - 1 # May raise AttributeError. 237 case "Darwin": 238 # May raise CalledProcessError. 239 out = subprocess.run( 240 ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True 241 ) 242 return int(out.stdout.strip()) - 1 243 case "Windows": 244 return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1 245 case _: 246 return 1 247 except (AttributeError, subprocess.CalledProcessError, KeyError): 248 return 1
Get a safe default number of CPUs available for multiprocessing.
On Linux platforms that support it, the method os.sched_getaffinity()
is used.
On Mac OS, the command sysctl -n hw.ncpu
is used.
On Windows, the environment variable NUMBER_OF_PROCESSORS
is queried.
If any of these fail, a fallback of 1 is used.
def diff_like(a):
    """Get forward difference of 2D array `a`, with repeated last elements.

    The repeated last elements ensure that output and input arrays have equal shape.

    Examples:

    >>> diff_like(np.array([1, 2, 3, 4, 5]))
    array([[1, 1, 1, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
    array([[1, 1, 1, 1, 1],
           [2, 3, 3, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
    array([[ 1.,  1.,  1.,  1.,  1.],
           [ 2.,  3.,  3.,  1.,  1.],
           [-1.,  0.,  0., inf, nan]])

    """
    arr = np.atleast_2d(a)
    # Linearly extrapolate one extra column (2·last − second-to-last) so that the
    # final forward difference repeats the last one and the shape is preserved.
    extra_col = (2 * arr[:, -1] - arr[:, -2]).reshape(arr.shape[0], 1)
    return np.diff(arr, append=extra_col)
Get forward difference of 2D array a
, with repeated last elements.
The repeated last elements ensure that output and input arrays have equal shape.
Examples:
>>> diff_like(np.array([1, 2, 3, 4, 5]))
array([[1, 1, 1, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
array([[1, 1, 1, 1, 1],
[2, 3, 3, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
array([[ 1., 1., 1., 1., 1.],
[ 2., 3., 3., 1., 1.],
[-1., 0., 0., inf, nan]])
def angle_fse_simpleshear(strain):
    """Get angle of FSE long axis anticlockwise from the X axis in simple shear."""
    # tan(angle) = strain + sqrt(strain² + 1); convert the result to degrees.
    tangent = strain + np.sqrt(strain**2 + 1)
    return np.degrees(np.arctan(tangent))
Get angle of FSE long axis anticlockwise from the X axis in simple shear.
def lag_2d_corner_flow(θ):
    """Get predicted grain orientation lag for 2D corner flow.

    See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222).

    """
    # Mask near-zero angles to avoid dividing by tan(θ) → 0 at θ = 0.
    angle = np.ma.masked_less(θ, 1e-15)
    cos_sq = np.cos(angle) ** 2
    common = angle**2 + cos_sq
    numerator = angle * common
    denominator = np.tan(angle) * (common - angle * np.sin(2 * angle))
    return numerator / denominator
Get predicted grain orientation lag for 2D corner flow.
See eq. 11 in Kaminski & Ribe (2002).
@nb.njit(fastmath=True)
def quat_product(q1, q2):
    """Quaternion product; q1, q2 and output are in scalar-last (x,y,z,w) format.

    Implements the Hamilton product:
    vector part = w1·v2 + w2·v1 + v1×v2, scalar part = w1·w2 − v1·v2.

    """
    # BUG FIX: the cross product previously used q1 for both operands
    # (np.cross(q1[:3], q1[:3])), which is identically zero.
    return [
        *q1[-1] * q2[:3] + q2[-1] * q1[:3] + np.cross(q1[:3], q2[:3]),
        q1[-1] * q2[-1] - np.dot(q1[:3], q2[:3]),
    ]
Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.
def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
    """Redraw legend on matplotlib axis or figure.

    Transparency is removed from legend symbols.
    If `fig` is not None and `remove_all` is True,
    all legends are first removed from the parent figure.
    Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default,
    or `matplotlib.figure.Figure.legend` if `fig` is not None.

    If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes
    instead of taking up space in the original axes. This option requires `fig=None`.

    .. warning::
        Note that if `fig` is not `None`, the legend may be cropped from the saved
        figure due to a Matplotlib bug. In this case, it is required to add the
        arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`,
        where `legend` is the object returned by this function. To prevent the legend
        from consuming axes/subplot space, it is further required to add the lines:
        `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)`
        and `fig.set_layout_engine("none")` before saving the figure.

    """
    # Strip alpha from legend symbols so they remain readable.
    handler_map = {
        PathCollection: HandlerPathCollection(
            update_func=_remove_legend_symbol_transparency
        ),
        Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency),
    }
    if fig is None:
        # BUG FIX: fetch handles/labels unconditionally so they are defined even
        # when the axes has no legend yet and `legendax` is given (previously an
        # UnboundLocalError).
        handles, labels = ax.get_legend_handles_labels()
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if legendax is not None:
            legendax.axis("off")
            return legendax.legend(handles, labels, handler_map=handler_map, **kwargs)
        return ax.legend(handler_map=handler_map, **kwargs)
    else:
        if legendax is not None:
            _log.warning("ignoring `legendax` argument which requires `fig=None`")
        # BUG FIX: iterate over a copy; `legend.remove()` mutates `fig.legends`,
        # and removing from the live list while iterating skips entries.
        for legend in list(fig.legends):
            if legend is not None:
                legend.remove()
        if remove_all:
            for ax in fig.axes:
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()
        return fig.legend(handler_map=handler_map, **kwargs)
Redraw legend on matplotlib axis or figure.
Transparency is removed from legend symbols.
If fig
is not None and remove_all
is True,
all legends are first removed from the parent figure.
Optional keyword arguments are passed to matplotlib.axes.Axes.legend
by default,
or matplotlib.figure.Figure.legend
if fig
is not None.
If legendax
is not None, the axis legend will be redrawn using the legendax
axes
instead of taking up space in the original axes. This option requires fig=None
.
Note that if fig
is not None
, the legend may be cropped from the saved
figure due to a Matplotlib bug. In this case, it is required to add the
arguments bbox_extra_artists=(legend,)
and bbox_inches="tight"
to savefig
,
where legend
is the object returned by this function. To prevent the legend
from consuming axes/subplot space, it is further required to add the lines:
legend.set_in_layout(False)
, fig.canvas.draw()
, legend.set_in_layout(True)
and fig.set_layout_engine("none")
before saving the figure.
def add_subplot_labels(
    mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs
):
    """Add subplot labels to axes mosaic.

    Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot
    labels. If `labelmap` is None, or has no entry for a given key, the key itself
    is used as the label.

    If `internal` is `False` (default), the axes titles will be used.
    Otherwise, internal labels will be drawn with `ax.text`,
    in which case `loc` must be a tuple of floats.

    Any axes in `mosaic` corresponding to the special key `legend` are skipped.
    Extra keyword arguments are forwarded to `ax.set_title` (non-internal labels
    only).

    """
    for txt, ax in mosaic.items():
        if txt.lower() == "legend":
            continue
        # BUG FIX: fall back to the mosaic key for keys missing from `labelmap`
        # (previously raised KeyError on a partial mapping).
        _txt = labelmap.get(txt, txt) if labelmap is not None else txt
        if internal:
            # Offset the label into the axes by (10, -5) points.
            trans = ScaledTranslation(10 / 72, -5 / 72, ax.figure.dpi_scale_trans)
            if isinstance(loc, str):
                raise ValueError(
                    "'loc' argument must be a sequence of float when 'internal' is 'True'"
                )
            ax.text(
                *loc,
                _txt,
                transform=ax.transAxes + trans,
                fontsize=fontsize,
                bbox={
                    "facecolor": (1.0, 1.0, 1.0, 0.3),
                    "edgecolor": "none",
                    "pad": 3.0,
                },
            )
        else:
            ax.set_title(_txt, loc=loc, fontsize=fontsize, **kwargs)
Add subplot labels to axes mosaic.
Use labelmap
to specify a dictionary that maps keys in mosaic
to subplot labels.
If labelmap
is None, the keys in axs
will be used as the labels by default.
If internal
is False
(default), the axes titles will be used.
Otherwise, internal labels will be drawn with ax.text
,
in which case loc
must be a tuple of floats.
Any axes in axs
corresponding to the special key legend
are skipped.