pydrex.utils
PyDRex: Miscellaneous utility methods.
1"""> PyDRex: Miscellaneous utility methods.""" 2 3import os 4import platform 5import subprocess 6import sys 7from functools import wraps 8 9import dill 10import numba as nb 11import numpy as np 12import scipy.special as sp 13from matplotlib.collections import PathCollection 14from matplotlib.legend_handler import HandlerLine2D, HandlerPathCollection 15from matplotlib.pyplot import Line2D 16from matplotlib.transforms import ScaledTranslation 17 18from pydrex import logger as _log 19 20 21def import_proc_pool() -> tuple: 22 """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`. 23 24 Import a process `Pool` object either from Ray of from Python's stdlib. 25 Both offer the same API, the Ray implementation will be preferred if available. 26 Using the `Pool` provided by Ray allows for distributed memory multiprocessing. 27 28 Returns a tuple containing the `Pool` object and a boolean flag which is `True` if 29 Ray is available. 30 31 """ 32 try: 33 from ray.util.multiprocessing import Pool 34 35 has_ray = True 36 except ImportError: 37 from multiprocessing import Pool 38 39 has_ray = False 40 return Pool, has_ray 41 42 43def in_ci(platform: str) -> bool: 44 """Check if we are in a GitHub runner with the given operating system.""" 45 # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables 46 return sys.platform == platform and os.getenv("CI") is not None 47 48 49class SerializedCallable: 50 """A serialized version of the callable f. 51 52 Serialization is performed using the dill library. The object is safe to pass into 53 `multiprocessing.Pool.map` and its alternatives. 54 55 .. note:: To serialize a lexical closure (i.e. a function defined inside a 56 function), use the `serializable` decorator. 
57 58 """ 59 60 def __init__(self, f): 61 self._f = dill.dumps(f, protocol=5, byref=True) 62 63 def __call__(self, *args, **kwargs): 64 return dill.loads(self._f)(*args, **kwargs) 65 66 67def serializable(f): 68 """Make decorated function serializable. 69 70 .. warning:: The decorated function cannot be a method, and it will loose its 71 docstring. It is not possible to use `functools.wraps` to mitigate this. 72 73 """ 74 return SerializedCallable(f) 75 76 77def defined_if(cond): 78 """Only define decorated function if `cond` is `True`.""" 79 80 def _defined_if(f): 81 def not_f(*args, **kwargs): 82 # Throw the same as we would get from `type(undefined_symbol)`. 83 raise NameError(f"name '{f.__name__}' is not defined") 84 85 @wraps(f) 86 def wrapper(*args, **kwargs): 87 if cond: 88 return f(*args, **kwargs) 89 return not_f(*args, **kwargs) 90 91 return wrapper 92 93 return _defined_if 94 95 96def halfspace( 97 age, z, surface_temp=273, diff_temp=1350, diffusivity=2.23e-6, fit="Korenaga2016" 98): 99 r"""Get halfspace cooling temperature based on the chosen fit. 100 101 $$T₀ + ΔT ⋅ \mathrm{erf}\left(\frac{z}{2 \sqrt{κ t}}\right) + Q$$ 102 103 Temperatures $T₀$ (surface), $ΔT$ (base - surface) and $Q$ (adiabatic correction) 104 are expected to be in Kelvin. The diffusivity $κ$ is expected to be in m²s⁻¹. Depth 105 $z$ is in metres and age $t$ is in seconds. Supported fits are: 106 - ["Korenaga2016"](http://dx.doi.org/10.1002/2016JB013395)¹, which implements $κ(z)$ 107 - "Standard", i.e. $Q = 0$ 108 109 ¹Although the fit is found in the 2016 paper, the equation is discussed as a 110 reference model in [Korenaga et al. 2021](https://doi.org/10.1029/2020JB021528). 111 The thermal diffusivity below 7km depth is hardcoded to 3.47e-7. 
112 113 """ 114 match fit: 115 case "Korenaga2016": 116 a1 = 0.602e-3 117 a2 = -6.045e-10 118 adiabatic = a1 * z + a2 * z**2 119 if z < 7: 120 κ = 3.45e-7 121 else: 122 b0 = -1.255 123 b1 = 9.944 124 b2 = -25.0619 125 b3 = 32.2944 126 b4 = -22.2017 127 b5 = 7.7336 128 b6 = -1.0622 129 coeffs = (b0, b1, b2, b3, b4, b5, b6) 130 z_ref = 1e5 131 κ_0 = diffusivity 132 κ = κ_0 * np.sum( 133 [b * (z / z_ref) ** (n / 2) for n, b in enumerate(coeffs)] 134 ) 135 case "Standard": 136 κ = diffusivity 137 adiabatic = 0.0 138 case _: 139 raise ValueError(f"unsupported fit '{fit}'") 140 return surface_temp + diff_temp * sp.erf(z / (2 * np.sqrt(κ * age))) + adiabatic 141 142 143@nb.njit(fastmath=True) 144def strain_increment(dt, velocity_gradient): 145 """Calculate strain increment for a given time increment and velocity gradient. 146 147 Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the 148 “(engineering) shear strain” increment. 149 150 """ 151 return ( 152 np.abs(dt) 153 * np.abs( 154 np.linalg.eigvalsh((velocity_gradient + velocity_gradient.transpose()) / 2) 155 ).max() 156 ) 157 158 159@nb.njit 160def apply_gbs( 161 orientations, fractions, gbs_threshold, orientations_prev, n_grains 162) -> tuple[np.ndarray, np.ndarray]: 163 """Apply grain boundary sliding for small grains.""" 164 mask = fractions < (gbs_threshold / n_grains) 165 # _log.debug( 166 # "grain boundary sliding activity (volume percentage): %s", 167 # len(np.nonzero(mask)) / len(fractions), 168 # ) 169 # No rotation: carry over previous orientations. 
170 orientations[mask, :, :] = orientations_prev[mask, :, :] 171 fractions[mask] = gbs_threshold / n_grains 172 fractions /= fractions.sum() 173 # _log.debug( 174 # "grain volume fractions: median=%e, min=%e, max=%e, sum=%e", 175 # np.median(fractions), 176 # np.min(fractions), 177 # np.max(fractions), 178 # np.sum(fractions), 179 # ) 180 return orientations, fractions 181 182 183@nb.njit 184def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]: 185 """Extract deformation gradient, orientation matrices and grain sizes from y.""" 186 deformation_gradient = y[:9].reshape((3, 3)) 187 orientations = y[9 : n_grains * 9 + 9].reshape((n_grains, 3, 3)).clip(-1, 1) 188 fractions = y[n_grains * 9 + 9 : n_grains * 10 + 9].clip(0, None) 189 fractions /= fractions.sum() 190 return deformation_gradient, orientations, fractions 191 192 193def pad_with(a, x=np.nan): 194 """Pad a list of arrays with `x` and return as a new 2D array with regular shape. 195 196 >>> pad_with([[1, 2, 3], [4, 5], [6]]) 197 array([[ 1., 2., 3.], 198 [ 4., 5., nan], 199 [ 6., nan, nan]]) 200 >>> pad_with([[1, 2, 3], [4, 5], [6]], x=0) 201 array([[1, 2, 3], 202 [4, 5, 0], 203 [6, 0, 0]]) 204 >>> pad_with([[1, 2, 3]]) 205 array([[1., 2., 3.]]) 206 >>> pad_with([[1, 2, 3]], x=0) 207 array([[1, 2, 3]]) 208 209 """ 210 longest = max([len(d) for d in a]) 211 out = np.full((len(a), longest), x) 212 for i, d in enumerate(a): 213 out[i, : len(d)] = d 214 return out 215 216 217def remove_nans(a): 218 """Remove NaN values from array.""" 219 a = np.asarray(a) 220 return a[~np.isnan(a)] 221 222 223def remove_dim(a, dim): 224 """Remove all values corresponding to dimension `dim` from an array. 225 226 Note that a `dim` of 0 refers to the “x” values. 
227 228 Examples: 229 230 >>> a = [1, 2, 3] 231 >>> remove_dim(a, 0) 232 array([2, 3]) 233 >>> remove_dim(a, 1) 234 array([1, 3]) 235 >>> remove_dim(a, 2) 236 array([1, 2]) 237 238 >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 239 >>> remove_dim(a, 0) 240 array([[5, 6], 241 [8, 9]]) 242 >>> remove_dim(a, 1) 243 array([[1, 3], 244 [7, 9]]) 245 >>> remove_dim(a, 2) 246 array([[1, 2], 247 [4, 5]]) 248 249 """ 250 _a = np.asarray(a) 251 for i, _ in enumerate(_a.shape): 252 _a = np.delete(_a, [dim], axis=i) 253 return _a 254 255 256def add_dim(a, dim, val=0): 257 """Add entries of `val` corresponding to dimension `dim` to an array. 258 259 Note that a `dim` of 0 refers to the “x” values. 260 261 Examples: 262 263 >>> a = [1, 2] 264 >>> add_dim(a, 0) 265 array([0, 1, 2]) 266 >>> add_dim(a, 1) 267 array([1, 0, 2]) 268 >>> add_dim(a, 2) 269 array([1, 2, 0]) 270 271 >>> add_dim([1.0, 2.0], 2) 272 array([1., 2., 0.]) 273 274 >>> a = [[1, 2], [3, 4]] 275 >>> add_dim(a, 0) 276 array([[0, 0, 0], 277 [0, 1, 2], 278 [0, 3, 4]]) 279 >>> add_dim(a, 1) 280 array([[1, 0, 2], 281 [0, 0, 0], 282 [3, 0, 4]]) 283 >>> add_dim(a, 2) 284 array([[1, 2, 0], 285 [3, 4, 0], 286 [0, 0, 0]]) 287 288 """ 289 _a = np.asarray(a) 290 for i, _ in enumerate(_a.shape): 291 _a = np.insert(_a, [dim], 0, axis=i) 292 return _a 293 294 295def default_ncpus() -> int: 296 """Get a safe default number of CPUs available for multiprocessing. 297 298 On Linux platforms that support it, the method `os.sched_getaffinity()` is used. 299 On Mac OS, the command `sysctl -n hw.ncpu` is used. 300 On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried. 301 If any of these fail, a fallback of 1 is used and a warning is logged. 302 303 """ 304 try: 305 match platform.system(): 306 case "Linux": 307 return len(os.sched_getaffinity(0)) - 1 # May raise AttributeError. 308 case "Darwin": 309 # May raise CalledProcessError. 
310 out = subprocess.run( 311 ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True 312 ) 313 return int(out.stdout.strip()) - 1 314 case "Windows": 315 return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1 316 case _: 317 return 1 318 except (AttributeError, subprocess.CalledProcessError, KeyError): 319 return 1 320 321 322def diff_like(a): 323 """Get forward difference of 2D array `a`, with repeated last elements. 324 325 The repeated last elements ensure that output and input arrays have equal shape. 326 327 Examples: 328 329 >>> diff_like(np.array([1, 2, 3, 4, 5])) 330 array([[1, 1, 1, 1, 1]]) 331 332 >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]])) 333 array([[1, 1, 1, 1, 1], 334 [2, 3, 3, 1, 1]]) 335 336 >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]])) 337 array([[ 1., 1., 1., 1., 1.], 338 [ 2., 3., 3., 1., 1.], 339 [-1., 0., 0., inf, nan]]) 340 341 """ 342 a2 = np.atleast_2d(a) 343 return np.diff( 344 a2, append=np.reshape(a2[:, -1] + (a2[:, -1] - a2[:, -2]), (a2.shape[0], 1)) 345 ) 346 347 348def angle_fse_simpleshear(strain): 349 """Get angle of FSE long axis anticlockwise from the X axis in simple shear.""" 350 return np.rad2deg(np.arctan(np.sqrt(strain**2 + 1) + strain)) 351 352 353def lag_2d_corner_flow(θ): 354 """Get predicted grain orientation lag for 2D corner flow. 355 356 See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222). 
357 358 """ 359 _θ = np.ma.masked_less(θ, 1e-15) 360 return (_θ * (_θ**2 + np.cos(_θ) ** 2)) / ( 361 np.tan(_θ) * (_θ**2 + np.cos(_θ) ** 2 - _θ * np.sin(2 * _θ)) 362 ) 363 364 365@nb.njit(fastmath=True) 366def quat_product(q1, q2): 367 """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.""" 368 return [ 369 *q1[-1] * q2[:3] + q2[-1] * q1[:3] + np.cross(q1[:3], q1[:3]), 370 q1[-1] * q2[-1] - np.dot(q1[:3], q2[:3]), 371 ] 372 373 374def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs): 375 """Redraw legend on matplotlib axis or figure. 376 377 Transparency is removed from legend symbols. 378 If `fig` is not None and `remove_all` is True, 379 all legends are first removed from the parent figure. 380 Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default, 381 or `matplotlib.figure.Figure.legend` if `fig` is not None. 382 383 If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes 384 instead of taking up space in the original axes. This option requires `fig=None`. 385 386 .. warning:: 387 Note that if `fig` is not `None`, the legend may be cropped from the saved 388 figure due to a Matplotlib bug. In this case, it is required to add the 389 arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`, 390 where `legend` is the object returned by this function. To prevent the legend 391 from consuming axes/subplot space, it is further required to add the lines: 392 `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_layout(True)` 393 and `fig.set_layout_engine("none")` before saving the figure. 
394 395 """ 396 handler_map = { 397 PathCollection: HandlerPathCollection( 398 update_func=_remove_legend_symbol_transparency 399 ), 400 Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency), 401 } 402 if fig is None: 403 legend = ax.get_legend() 404 if legend is not None: 405 handles, labels = ax.get_legend_handles_labels() 406 legend.remove() 407 if legendax is not None: 408 legendax.axis("off") 409 return legendax.legend(handles, labels, handler_map=handler_map, **kwargs) 410 return ax.legend(handler_map=handler_map, **kwargs) 411 else: 412 if legendax is not None: 413 _log.warning("ignoring `legendax` argument which requires `fig=None`") 414 for legend in fig.legends: 415 if legend is not None: 416 legend.remove() 417 if remove_all: 418 for ax in fig.axes: 419 legend = ax.get_legend() 420 if legend is not None: 421 legend.remove() 422 return fig.legend(handler_map=handler_map, **kwargs) 423 424 425def add_subplot_labels( 426 mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs 427): 428 """Add subplot labels to axes mosaic. 429 430 Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels. 431 If `labelmap` is None, the keys in `axs` will be used as the labels by default. 432 433 If `internal` is `False` (default), the axes titles will be used. 434 Otherwise, internal labels will be drawn with `ax.text`, 435 in which case `loc` must be a tuple of floats. 436 437 Any axes in `axs` corresponding to the special key `legend` are skipped. 
438 439 """ 440 for txt, ax in mosaic.items(): 441 if txt.lower() == "legend": 442 continue 443 _txt = labelmap[txt] if labelmap is not None else txt 444 if internal: 445 trans = ScaledTranslation(10 / 72, -5 / 72, ax.figure.dpi_scale_trans) 446 if isinstance(loc, str): 447 raise ValueError( 448 "'loc' argument must be a sequence of float when 'internal' is 'True'" 449 ) 450 ax.text( 451 *loc, 452 _txt, 453 transform=ax.transAxes + trans, 454 fontsize=fontsize, 455 bbox={ 456 "facecolor": (1.0, 1.0, 1.0, 0.3), 457 "edgecolor": "none", 458 "pad": 3.0, 459 }, 460 ) 461 else: 462 ax.set_title(_txt, loc=loc, fontsize=fontsize, **kwargs) 463 464 465def _remove_legend_symbol_transparency(handle, orig): 466 """Remove transparency from symbols used in a Matplotlib legend.""" 467 # https://stackoverflow.com/a/59629242/12519962 468 handle.update_from(orig) 469 handle.set_alpha(1)
def import_proc_pool() -> tuple:
    """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.

    Import a process `Pool` object either from Ray or from Python's stdlib.
    Both offer the same API, the Ray implementation will be preferred if available.
    Using the `Pool` provided by Ray allows for distributed memory multiprocessing.

    Returns a tuple containing the `Pool` object and a boolean flag which is `True` if
    Ray is available.

    """
    try:
        from ray.util.multiprocessing import Pool

        has_ray = True
    except ImportError:
        # Ray is optional: fall back to the standard library implementation.
        from multiprocessing import Pool

        has_ray = False
    return Pool, has_ray
Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.
Import a process `Pool` object either from Ray or from Python's stdlib.
Both offer the same API; the Ray implementation is preferred if available.
Using the `Pool` provided by Ray allows for distributed memory multiprocessing.
Returns a tuple containing the `Pool` object and a boolean flag which is `True` if Ray is available.
def in_ci(platform: str) -> bool:
    """Return whether we are running in CI on the operating system named `platform`.

    `platform` is compared with `sys.platform` (e.g. "linux", "darwin", "win32").
    Being "in CI" only means that the `CI` environment variable is set, as it is on
    GitHub runners.

    """
    # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
    if sys.platform != platform:
        return False
    return "CI" in os.environ
Check if we are in a GitHub runner with the given operating system.
class SerializedCallable:
    """Callable wrapper holding a dill-serialized copy of the callable `f`.

    `f` is serialized once, when the wrapper is constructed, and deserialized again on
    every call before being invoked with the given arguments. The object is safe to
    pass into `multiprocessing.Pool.map` and its alternatives.

    .. note:: To serialize a lexical closure (i.e. a function defined inside a
        function), use the `serializable` decorator.

    """

    def __init__(self, f):
        serialized = dill.dumps(f, protocol=5, byref=True)
        self._f = serialized

    def __call__(self, *args, **kwargs):
        target = dill.loads(self._f)
        return target(*args, **kwargs)
A serialized version of the callable f.
Serialization is performed using the dill library. The object is safe to pass into `multiprocessing.Pool.map` and its alternatives.
To serialize a lexical closure (i.e. a function defined inside a function), use the `serializable` decorator.
def serializable(f):
    """Make decorated function serializable.

    The function is wrapped in a `SerializedCallable`.

    .. warning:: The decorated function cannot be a method, and it will lose its
        docstring. It is not possible to use `functools.wraps` to mitigate this.

    """
    return SerializedCallable(f)
Make decorated function serializable.
The decorated function cannot be a method, and it will lose its docstring. It is not possible to use `functools.wraps` to mitigate this.
def defined_if(cond):
    """Only define decorated function if `cond` is `True`.

    `cond` is checked for truthiness every time the decorated function is called.
    When it is falsy, the call raises `NameError`, mimicking a reference to a name
    that was never defined. The wrapper keeps the metadata of the original function.

    """

    def decorator(func):
        @wraps(func)
        def guarded(*args, **kwargs):
            if not cond:
                # Same error as referencing a name that was never defined.
                raise NameError(f"name '{func.__name__}' is not defined")
            return func(*args, **kwargs)

        return guarded

    return decorator
Only define the decorated function if `cond` is `True`; otherwise calling it raises `NameError`.
97def halfspace( 98 age, z, surface_temp=273, diff_temp=1350, diffusivity=2.23e-6, fit="Korenaga2016" 99): 100 r"""Get halfspace cooling temperature based on the chosen fit. 101 102 $$T₀ + ΔT ⋅ \mathrm{erf}\left(\frac{z}{2 \sqrt{κ t}}\right) + Q$$ 103 104 Temperatures $T₀$ (surface), $ΔT$ (base - surface) and $Q$ (adiabatic correction) 105 are expected to be in Kelvin. The diffusivity $κ$ is expected to be in m²s⁻¹. Depth 106 $z$ is in metres and age $t$ is in seconds. Supported fits are: 107 - ["Korenaga2016"](http://dx.doi.org/10.1002/2016JB013395)¹, which implements $κ(z)$ 108 - "Standard", i.e. $Q = 0$ 109 110 ¹Although the fit is found in the 2016 paper, the equation is discussed as a 111 reference model in [Korenaga et al. 2021](https://doi.org/10.1029/2020JB021528). 112 The thermal diffusivity below 7km depth is hardcoded to 3.47e-7. 113 114 """ 115 match fit: 116 case "Korenaga2016": 117 a1 = 0.602e-3 118 a2 = -6.045e-10 119 adiabatic = a1 * z + a2 * z**2 120 if z < 7: 121 κ = 3.45e-7 122 else: 123 b0 = -1.255 124 b1 = 9.944 125 b2 = -25.0619 126 b3 = 32.2944 127 b4 = -22.2017 128 b5 = 7.7336 129 b6 = -1.0622 130 coeffs = (b0, b1, b2, b3, b4, b5, b6) 131 z_ref = 1e5 132 κ_0 = diffusivity 133 κ = κ_0 * np.sum( 134 [b * (z / z_ref) ** (n / 2) for n, b in enumerate(coeffs)] 135 ) 136 case "Standard": 137 κ = diffusivity 138 adiabatic = 0.0 139 case _: 140 raise ValueError(f"unsupported fit '{fit}'") 141 return surface_temp + diff_temp * sp.erf(z / (2 * np.sqrt(κ * age))) + adiabatic
Get halfspace cooling temperature based on the chosen fit.
$$T₀ + ΔT ⋅ \mathrm{erf}\left(\frac{z}{2 \sqrt{κ t}}\right) + Q$$
Temperatures $T₀$ (surface), $ΔT$ (base - surface) and $Q$ (adiabatic correction) are expected to be in Kelvin. The diffusivity $κ$ is expected to be in m²s⁻¹. Depth $z$ is in metres and age $t$ is in seconds. Supported fits are:
- "Korenaga2016"¹, which implements $κ(z)$
- "Standard", i.e. $Q = 0$
¹Although the fit is found in the 2016 paper, the equation is discussed as a reference model in Korenaga et al. 2021. At depths shallower than 7 km, the thermal diffusivity is hardcoded to 3.45e-7 m²s⁻¹.
@nb.njit(fastmath=True)
def strain_increment(dt, velocity_gradient):
    """Calculate strain increment for a given time increment and velocity gradient.

    The symmetric part of the velocity gradient (the strain-rate tensor) is formed and
    the result is |dt| times its largest absolute eigenvalue. This is the “tensorial”
    strain increment ε, which is equal to γ/2 where γ is the “(engineering) shear
    strain” increment.

    """
    strain_rate = (velocity_gradient + velocity_gradient.transpose()) / 2
    max_rate = np.abs(np.linalg.eigvalsh(strain_rate)).max()
    return np.abs(dt) * max_rate
Calculate strain increment for a given time increment and velocity gradient.
Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the “(engineering) shear strain” increment.
@nb.njit
def apply_gbs(
    orientations, fractions, gbs_threshold, orientations_prev, n_grains
) -> tuple[np.ndarray, np.ndarray]:
    """Apply grain boundary sliding for small grains.

    Grains whose volume fraction is below `gbs_threshold / n_grains` keep their
    previous orientation (no rotation) and have their fraction raised to that value.
    All fractions are then renormalised to sum to one. `orientations` and `fractions`
    are modified in place and also returned.

    """
    min_fraction = gbs_threshold / n_grains
    small = fractions < min_fraction
    # No rotation: carry over previous orientations.
    orientations[small, :, :] = orientations_prev[small, :, :]
    fractions[small] = min_fraction
    fractions /= fractions.sum()
    return orientations, fractions
Apply grain boundary sliding for small grains.
@nb.njit
def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Extract deformation gradient, orientation matrices and grain sizes from y.

    `y` is read as a flat array with this layout:
    - 9 entries for the 3×3 deformation gradient;
    - then 9 entries per grain for the `n_grains` orientation matrices (clipped to [-1, 1]);
    - then one volume fraction per grain (clipped to be non-negative and renormalised to
      sum to one).

    """
    ori_start = 9
    frac_start = ori_start + 9 * n_grains
    frac_stop = frac_start + n_grains
    deformation_gradient = y[:ori_start].reshape((3, 3))
    orientations = y[ori_start:frac_start].reshape((n_grains, 3, 3)).clip(-1, 1)
    fractions = y[frac_start:frac_stop].clip(0, None)
    fractions /= fractions.sum()
    return deformation_gradient, orientations, fractions
Extract deformation gradient, orientation matrices and grain sizes from y.
def pad_with(a, x=np.nan):
    """Pad a list of arrays with `x` and return as a new 2D array with regular shape.

    The output dtype can hold both `x` and every row of `a`, so float rows padded with
    an integer `x` are not truncated. An empty `a` yields an array of shape (0, 0).

    >>> pad_with([[1, 2, 3], [4, 5], [6]])
    array([[ 1.,  2.,  3.],
           [ 4.,  5., nan],
           [ 6., nan, nan]])
    >>> pad_with([[1, 2, 3], [4, 5], [6]], x=0)
    array([[1, 2, 3],
           [4, 5, 0],
           [6, 0, 0]])
    >>> pad_with([[1, 2, 3]])
    array([[1., 2., 3.]])
    >>> pad_with([[1, 2, 3]], x=0)
    array([[1, 2, 3]])
    >>> pad_with([[1.5, 2.5], [3.5]], x=0)
    array([[1.5, 2.5],
           [3.5, 0. ]])

    """
    rows = [np.asarray(d) for d in a]
    longest = max((len(d) for d in rows), default=0)
    # Empty rows carry no data (and default to float64), so ignore them here.
    dtype = np.result_type(x, *(d.dtype for d in rows if d.size))
    out = np.full((len(rows), longest), x, dtype=dtype)
    for i, d in enumerate(rows):
        out[i, : len(d)] = d
    return out
Pad a list of arrays with `x` and return as a new 2D array with regular shape.
>>> pad_with([[1, 2, 3], [4, 5], [6]])
array([[ 1., 2., 3.],
[ 4., 5., nan],
[ 6., nan, nan]])
>>> pad_with([[1, 2, 3], [4, 5], [6]], x=0)
array([[1, 2, 3],
[4, 5, 0],
[6, 0, 0]])
>>> pad_with([[1, 2, 3]])
array([[1., 2., 3.]])
>>> pad_with([[1, 2, 3]], x=0)
array([[1, 2, 3]])
def remove_nans(a):
    """Return the non-NaN entries of `a` (converted with `np.asarray`).

    Boolean-mask indexing is used, so a multi-dimensional input yields a 1-D result.

    """
    arr = np.asarray(a)
    keep = np.logical_not(np.isnan(arr))
    return arr[keep]
Remove NaN values from array.
def remove_dim(a, dim):
    """Remove all values corresponding to dimension `dim` from an array.

    Index `dim` is deleted along every axis of `a`, so an N-D input keeps its rank but
    shrinks by one along each axis. Note that a `dim` of 0 refers to the “x” values.

    Examples:

    >>> a = [1, 2, 3]
    >>> remove_dim(a, 0)
    array([2, 3])
    >>> remove_dim(a, 1)
    array([1, 3])
    >>> remove_dim(a, 2)
    array([1, 2])

    >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    >>> remove_dim(a, 0)
    array([[5, 6],
           [8, 9]])
    >>> remove_dim(a, 1)
    array([[1, 3],
           [7, 9]])
    >>> remove_dim(a, 2)
    array([[1, 2],
           [4, 5]])

    """
    out = np.asarray(a)
    for axis in range(out.ndim):
        out = np.delete(out, [dim], axis=axis)
    return out
Remove all values corresponding to dimension `dim` from an array.
Note that a `dim` of 0 refers to the “x” values.
Examples:
>>> a = [1, 2, 3]
>>> remove_dim(a, 0)
array([2, 3])
>>> remove_dim(a, 1)
array([1, 3])
>>> remove_dim(a, 2)
array([1, 2])
>>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
>>> remove_dim(a, 0)
array([[5, 6],
[8, 9]])
>>> remove_dim(a, 1)
array([[1, 3],
[7, 9]])
>>> remove_dim(a, 2)
array([[1, 2],
[4, 5]])
def add_dim(a, dim, val=0):
    """Add entries of `val` corresponding to dimension `dim` to an array.

    Note that a `dim` of 0 refers to the “x” values. Along every axis, a slice filled
    with `val` is inserted before index `dim`; `val` is cast to the dtype of `a`.

    Examples:

    >>> a = [1, 2]
    >>> add_dim(a, 0)
    array([0, 1, 2])
    >>> add_dim(a, 1)
    array([1, 0, 2])
    >>> add_dim(a, 2)
    array([1, 2, 0])

    >>> add_dim([1.0, 2.0], 2)
    array([1., 2., 0.])
    >>> add_dim([1.0, 2.0], 2, val=9.0)
    array([1., 2., 9.])

    >>> a = [[1, 2], [3, 4]]
    >>> add_dim(a, 0)
    array([[0, 0, 0],
           [0, 1, 2],
           [0, 3, 4]])
    >>> add_dim(a, 1)
    array([[1, 0, 2],
           [0, 0, 0],
           [3, 0, 4]])
    >>> add_dim(a, 2)
    array([[1, 2, 0],
           [3, 4, 0],
           [0, 0, 0]])

    """
    _a = np.asarray(a)
    for i in range(_a.ndim):
        # Use the requested fill value (previously a hard-coded 0).
        _a = np.insert(_a, [dim], val, axis=i)
    return _a
Add entries of `val` corresponding to dimension `dim` to an array.
Note that a `dim` of 0 refers to the “x” values.
Examples:
>>> a = [1, 2]
>>> add_dim(a, 0)
array([0, 1, 2])
>>> add_dim(a, 1)
array([1, 0, 2])
>>> add_dim(a, 2)
array([1, 2, 0])
>>> add_dim([1.0, 2.0], 2)
array([1., 2., 0.])
>>> a = [[1, 2], [3, 4]]
>>> add_dim(a, 0)
array([[0, 0, 0],
[0, 1, 2],
[0, 3, 4]])
>>> add_dim(a, 1)
array([[1, 0, 2],
[0, 0, 0],
[3, 0, 4]])
>>> add_dim(a, 2)
array([[1, 2, 0],
[3, 4, 0],
[0, 0, 0]])
296def default_ncpus() -> int: 297 """Get a safe default number of CPUs available for multiprocessing. 298 299 On Linux platforms that support it, the method `os.sched_getaffinity()` is used. 300 On Mac OS, the command `sysctl -n hw.ncpu` is used. 301 On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried. 302 If any of these fail, a fallback of 1 is used and a warning is logged. 303 304 """ 305 try: 306 match platform.system(): 307 case "Linux": 308 return len(os.sched_getaffinity(0)) - 1 # May raise AttributeError. 309 case "Darwin": 310 # May raise CalledProcessError. 311 out = subprocess.run( 312 ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True 313 ) 314 return int(out.stdout.strip()) - 1 315 case "Windows": 316 return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1 317 case _: 318 return 1 319 except (AttributeError, subprocess.CalledProcessError, KeyError): 320 return 1
Get a safe default number of CPUs available for multiprocessing.
On Linux platforms that support it, the method `os.sched_getaffinity()` is used.
On Mac OS, the command `sysctl -n hw.ncpu` is used.
On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried.
One CPU is subtracted from the detected count to leave room for the main process.
If any of these methods fail, or the platform is not recognised, a fallback of 1 is used.
def diff_like(a):
    """Get forward difference of 2D array `a`, with repeated last elements.

    Each row is extended by one linearly extrapolated value (last + (last - previous))
    before differencing. The final difference therefore repeats the previous one, and
    output and input have equal shape. A 1-D input is treated as a single row.

    Examples:

    >>> diff_like(np.array([1, 2, 3, 4, 5]))
    array([[1, 1, 1, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
    array([[1, 1, 1, 1, 1],
           [2, 3, 3, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
    array([[ 1.,  1.,  1.,  1.,  1.],
           [ 2.,  3.,  3.,  1.,  1.],
           [-1.,  0.,  0., inf, nan]])

    """
    rows = np.atleast_2d(a)
    last = rows[:, -1]
    extra_column = (last + (last - rows[:, -2])).reshape((rows.shape[0], 1))
    return np.diff(rows, append=extra_column)
Get forward difference of 2D array `a`, with repeated last elements.
The repeated last elements ensure that output and input arrays have equal shape.
Examples:
>>> diff_like(np.array([1, 2, 3, 4, 5]))
array([[1, 1, 1, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
array([[1, 1, 1, 1, 1],
[2, 3, 3, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
array([[ 1., 1., 1., 1., 1.],
[ 2., 3., 3., 1., 1.],
[-1., 0., 0., inf, nan]])
def angle_fse_simpleshear(strain):
    """Get angle of FSE long axis anticlockwise from the X axis in simple shear.

    The angle is returned in degrees as arctan(strain + sqrt(strain² + 1)). It is 45° at
    zero strain and increases monotonically with `strain` towards 90°.

    """
    tangent = strain + np.sqrt(1 + strain**2)
    return np.degrees(np.arctan(tangent))
Get angle of FSE long axis anticlockwise from the X axis in simple shear.
def lag_2d_corner_flow(θ):
    """Get predicted grain orientation lag for 2D corner flow.

    See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222).

    `θ` is passed directly to `np.cos`/`np.sin`/`np.tan`, so it must be in radians.
    Values below 1e-15 are masked, so the result is a `numpy.ma` masked array.

    """
    angle = np.ma.masked_less(θ, 1e-15)
    common = angle**2 + np.cos(angle) ** 2
    numerator = angle * common
    denominator = np.tan(angle) * (common - angle * np.sin(2 * angle))
    return numerator / denominator
Get predicted grain orientation lag for 2D corner flow.
See eq. 11 in Kaminski & Ribe (2002).
@nb.njit(fastmath=True)
def quat_product(q1, q2):
    """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.

    For q1 = (v1, w1) and q2 = (v2, w2) the product is
    (w1 v2 + w2 v1 + v1 × v2, w1 w2 − v1 · v2), returned as a list [x, y, z, w].

    """
    v1 = q1[:3]
    v2 = q2[:3]
    w1 = q1[-1]
    w2 = q2[-1]
    # The cross term must pair both vector parts (v1 × v1 is always zero).
    v = w1 * v2 + w2 * v1 + np.cross(v1, v2)
    return [v[0], v[1], v[2], w1 * w2 - np.dot(v1, v2)]
Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.
def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
    """Redraw legend on matplotlib axis or figure.

    Transparency is removed from legend symbols.
    If `fig` is not None, all figure-level legends are first removed from it, and if
    `remove_all` is also True, the legends of all axes in `fig` are removed as well.
    Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default,
    or `matplotlib.figure.Figure.legend` if `fig` is not None.

    If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes
    instead of taking up space in the original axes. This option requires `fig=None`.

    Returns the newly created legend.

    .. warning::
        Note that if `fig` is not `None`, the legend may be cropped from the saved
        figure due to a Matplotlib bug. In this case, it is required to add the
        arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`,
        where `legend` is the object returned by this function. To prevent the legend
        from consuming axes/subplot space, it is further required to add the lines:
        `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)`
        and `fig.set_layout_engine("none")` before saving the figure.

    """
    handler_map = {
        PathCollection: HandlerPathCollection(
            update_func=_remove_legend_symbol_transparency
        ),
        Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency),
    }
    if fig is None:
        # Gather entries up front so `legendax` also works when `ax` has no legend yet.
        handles, labels = ax.get_legend_handles_labels()
        old_legend = ax.get_legend()
        if old_legend is not None:
            old_legend.remove()
        if legendax is not None:
            legendax.axis("off")
            return legendax.legend(handles, labels, handler_map=handler_map, **kwargs)
        return ax.legend(handler_map=handler_map, **kwargs)

    if legendax is not None:
        _log.warning("ignoring `legendax` argument which requires `fig=None`")
    # Iterate over a copy: removing a figure legend deletes it from `fig.legends`,
    # and mutating that list while looping over it would skip every other legend.
    for legend in list(fig.legends):
        if legend is not None:
            legend.remove()
    if remove_all:
        for axes in fig.axes:
            legend = axes.get_legend()
            if legend is not None:
                legend.remove()
    return fig.legend(handler_map=handler_map, **kwargs)
Redraw legend on matplotlib axis or figure.
Transparency is removed from legend symbols.
If `fig` is not None, all figure-level legends are first removed, and if `remove_all` is also True, the legends of all axes in the figure are removed as well.
Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default, or `matplotlib.figure.Figure.legend` if `fig` is not None.
If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes instead of taking up space in the original axes. This option requires `fig=None`.
Note that if `fig` is not `None`, the legend may be cropped from the saved figure due to a Matplotlib bug. In this case, it is required to add the arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`, where `legend` is the object returned by this function. To prevent the legend from consuming axes/subplot space, it is further required to add the lines: `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)` and `fig.set_layout_engine("none")` before saving the figure.
def add_subplot_labels(
    mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs
):
    """Add subplot labels to axes mosaic.

    `mosaic` maps keys to axes (presumably as returned by
    `matplotlib.pyplot.subplot_mosaic`).
    Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels.
    If `labelmap` is None, the keys in `mosaic` will be used as the labels by default.

    If `internal` is `False` (default), the axes titles will be used and extra keyword
    arguments are passed to `ax.set_title`.
    Otherwise, internal labels will be drawn with `ax.text`,
    in which case `loc` must be a tuple of floats (axes coordinates).

    Any axes in `mosaic` corresponding to the special key `legend` (any letter case)
    are skipped.

    Raises `ValueError` if `internal` is `True` and `loc` is a string.

    """
    if internal and isinstance(loc, str):
        raise ValueError(
            "'loc' argument must be a sequence of float when 'internal' is 'True'"
        )
    for key, ax in mosaic.items():
        # Mosaic keys need not be strings; only a string key can name the legend slot.
        if isinstance(key, str) and key.lower() == "legend":
            continue
        label = labelmap[key] if labelmap is not None else key
        if internal:
            # Nudge the label 10 pt right and 5 pt down (1 pt = 1/72 inch).
            offset = ScaledTranslation(10 / 72, -5 / 72, ax.figure.dpi_scale_trans)
            ax.text(
                *loc,
                label,
                transform=ax.transAxes + offset,
                fontsize=fontsize,
                bbox={
                    "facecolor": (1.0, 1.0, 1.0, 0.3),
                    "edgecolor": "none",
                    "pad": 3.0,
                },
            )
        else:
            ax.set_title(label, loc=loc, fontsize=fontsize, **kwargs)
Add subplot labels to axes mosaic.
Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels.
If `labelmap` is None, the keys in `mosaic` will be used as the labels by default.
If `internal` is `False` (default), the axes titles will be used.
Otherwise, internal labels will be drawn with `ax.text`, in which case `loc` must be a tuple of floats.
Any axes in `mosaic` corresponding to the special key `legend` are skipped.