pydrex.utils

PyDRex: Miscellaneous utility methods.

  1"""> PyDRex: Miscellaneous utility methods."""
  2
  3import os
  4import platform
  5import subprocess
  6import sys
  7from functools import wraps
  8
  9import dill
 10import numba as nb
 11import numpy as np
 12import scipy.special as sp
 13from matplotlib.collections import PathCollection
 14from matplotlib.legend_handler import HandlerLine2D, HandlerPathCollection
 15from matplotlib.pyplot import Line2D
 16from matplotlib.transforms import ScaledTranslation
 17
 18from pydrex import logger as _log
 19
 20
 21def import_proc_pool() -> tuple:
 22    """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.
 23
 24    Import a process `Pool` object either from Ray of from Python's stdlib.
 25    Both offer the same API, the Ray implementation will be preferred if available.
 26    Using the `Pool` provided by Ray allows for distributed memory multiprocessing.
 27
 28    Returns a tuple containing the `Pool` object and a boolean flag which is `True` if
 29    Ray is available.
 30
 31    """
 32    try:
 33        from ray.util.multiprocessing import Pool
 34
 35        has_ray = True
 36    except ImportError:
 37        from multiprocessing import Pool
 38
 39        has_ray = False
 40    return Pool, has_ray
 41
 42
 43def in_ci(platform: str) -> bool:
 44    """Check if we are in a GitHub runner with the given operating system."""
 45    # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
 46    return sys.platform == platform and os.getenv("CI") is not None
 47
 48
 49class SerializedCallable:
 50    """A serialized version of the callable f.
 51
 52    Serialization is performed using the dill library. The object is safe to pass into
 53    `multiprocessing.Pool.map` and its alternatives.
 54
 55    .. note:: To serialize a lexical closure (i.e. a function defined inside a
 56        function), use the `serializable` decorator.
 57
 58    """
 59
 60    def __init__(self, f):
 61        self._f = dill.dumps(f, protocol=5, byref=True)
 62
 63    def __call__(self, *args, **kwargs):
 64        return dill.loads(self._f)(*args, **kwargs)
 65
 66
def serializable(f):
    """Make decorated function serializable.

    Wraps `f` in a `SerializedCallable`, allowing it (including lexical
    closures) to be passed to `multiprocessing.Pool.map` and its alternatives.

    .. warning:: The decorated function cannot be a method, and it will lose its
        docstring. It is not possible to use `functools.wraps` to mitigate this.

    """
    return SerializedCallable(f)
 75
 76
 77def defined_if(cond):
 78    """Only define decorated function if `cond` is `True`."""
 79
 80    def _defined_if(f):
 81        def not_f(*args, **kwargs):
 82            # Throw the same as we would get from `type(undefined_symbol)`.
 83            raise NameError(f"name '{f.__name__}' is not defined")
 84
 85        @wraps(f)
 86        def wrapper(*args, **kwargs):
 87            if cond:
 88                return f(*args, **kwargs)
 89            return not_f(*args, **kwargs)
 90
 91        return wrapper
 92
 93    return _defined_if
 94
 95
def halfspace(
    age, z, surface_temp=273, diff_temp=1350, diffusivity=2.23e-6, fit="Korenaga2016"
):
    r"""Get halfspace cooling temperature based on the chosen fit.

    $$T₀ + ΔT ⋅ \mathrm{erf}\left(\frac{z}{2 \sqrt{κ t}}\right) + Q$$

    Temperatures $T₀$ (surface), $ΔT$ (base - surface) and $Q$ (adiabatic correction)
    are expected to be in Kelvin. The diffusivity $κ$ is expected to be in m²s⁻¹. Depth
    $z$ is in metres and age $t$ is in seconds. Supported fits are:
    - ["Korenaga2016"](http://dx.doi.org/10.1002/2016JB013395)¹, which implements $κ(z)$
    - "Standard", i.e. $Q = 0$

    ¹Although the fit is found in the 2016 paper, the equation is discussed as a
    reference model in [Korenaga et al. 2021](https://doi.org/10.1029/2020JB021528).
    The thermal diffusivity below 7km depth is hardcoded to 3.47e-7.

    .. note:: NOTE(review): two apparent inconsistencies with the text above —
        the code uses 3.45e-7 (not 3.47e-7) for the shallow diffusivity, and the
        shallow branch triggers for ``z < 7`` although ``z`` is documented to be
        in metres (7 km would be ``z < 7e3``). Confirm the intended value and
        units against the Korenaga references.

    """
    match fit:
        case "Korenaga2016":
            # Adiabatic correction, quadratic in depth: Q = a₁z + a₂z².
            a1 = 0.602e-3
            a2 = -6.045e-10
            adiabatic = a1 * z + a2 * z**2
            if z < 7:
                # Constant diffusivity in the shallow region (see NOTE above
                # regarding the value and the depth threshold).
                κ = 3.45e-7
            else:
                # Depth-dependent diffusivity: κ(z) = κ₀ Σₙ bₙ (z/z_ref)^(n/2).
                b0 = -1.255
                b1 = 9.944
                b2 = -25.0619
                b3 = 32.2944
                b4 = -22.2017
                b5 = 7.7336
                b6 = -1.0622
                coeffs = (b0, b1, b2, b3, b4, b5, b6)
                z_ref = 1e5
                κ_0 = diffusivity
                κ = κ_0 * np.sum(
                    [b * (z / z_ref) ** (n / 2) for n, b in enumerate(coeffs)]
                )
        case "Standard":
            # Uniform diffusivity, no adiabatic correction (Q = 0).
            κ = diffusivity
            adiabatic = 0.0
        case _:
            raise ValueError(f"unsupported fit '{fit}'")
    return surface_temp + diff_temp * sp.erf(z / (2 * np.sqrt(κ * age))) + adiabatic
141
142
@nb.njit(fastmath=True)
def strain_increment(dt, velocity_gradient):
    """Calculate strain increment for a given time increment and velocity gradient.

    Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the
    “(engineering) shear strain” increment.

    """
    # Symmetric part of the velocity gradient (the strain rate tensor).
    strain_rate = (velocity_gradient + velocity_gradient.transpose()) / 2
    # Largest principal strain-rate magnitude, scaled by the timestep size.
    return np.abs(dt) * np.abs(np.linalg.eigvalsh(strain_rate)).max()
157
158
@nb.njit
def apply_gbs(
    orientations, fractions, gbs_threshold, orientations_prev, n_grains
) -> tuple[np.ndarray, np.ndarray]:
    """Apply grain boundary sliding for small grains."""
    # Grains with volume fraction below this threshold undergo sliding.
    threshold_volume = gbs_threshold / n_grains
    sliding = fractions < threshold_volume
    # Sliding grains do not rotate: carry over their previous orientations.
    orientations[sliding, :, :] = orientations_prev[sliding, :, :]
    # Reset sliding grains to the threshold volume, then renormalize so the
    # volume fractions sum to 1.
    fractions[sliding] = threshold_volume
    fractions /= fractions.sum()
    return orientations, fractions
181
182
@nb.njit
def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Extract deformation gradient, orientation matrices and grain sizes from y.

    Layout of the 1D state vector `y` (with n = n_grains):
    - y[:9] — flattened 3x3 deformation gradient
    - y[9:9n+9] — n flattened 3x3 orientation matrices
    - y[9n+9:10n+9] — n grain volume fractions

    """
    n_orient = n_grains * 9
    deformation_gradient = y[:9].reshape((3, 3))
    # Clip orientation matrix entries to valid direction-cosine range.
    orientations = y[9 : n_orient + 9].reshape((n_grains, 3, 3)).clip(-1, 1)
    # Volume fractions cannot be negative; renormalize to sum to 1.
    fractions = y[n_orient + 9 : n_orient + 9 + n_grains].clip(0, None)
    fractions /= fractions.sum()
    return deformation_gradient, orientations, fractions
191
192
def pad_with(a, x=np.nan):
    """Pad a list of arrays with `x` and return as a new 2D array with regular shape.

    >>> pad_with([[1, 2, 3], [4, 5], [6]])
    array([[ 1.,  2.,  3.],
           [ 4.,  5., nan],
           [ 6., nan, nan]])
    >>> pad_with([[1, 2, 3], [4, 5], [6]], x=0)
    array([[1, 2, 3],
           [4, 5, 0],
           [6, 0, 0]])
    >>> pad_with([[1, 2, 3]])
    array([[1., 2., 3.]])
    >>> pad_with([[1, 2, 3]], x=0)
    array([[1, 2, 3]])

    """
    # Output dtype follows `x` via np.full, so the default NaN fill yields floats.
    width = max(map(len, a))
    padded = np.full((len(a), width), x)
    for row, values in zip(padded, a):
        row[: len(values)] = values
    return padded
215
216
def remove_nans(a):
    """Remove NaN values from array."""
    arr = np.asarray(a)
    keep = ~np.isnan(arr)
    return arr[keep]
221
222
def remove_dim(a, dim):
    """Remove all values corresponding to dimension `dim` from an array.

    Note that a `dim` of 0 refers to the “x” values.

    Examples:

    >>> a = [1, 2, 3]
    >>> remove_dim(a, 0)
    array([2, 3])
    >>> remove_dim(a, 1)
    array([1, 3])
    >>> remove_dim(a, 2)
    array([1, 2])

    >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    >>> remove_dim(a, 0)
    array([[5, 6],
           [8, 9]])
    >>> remove_dim(a, 1)
    array([[1, 3],
           [7, 9]])
    >>> remove_dim(a, 2)
    array([[1, 2],
           [4, 5]])

    """
    out = np.asarray(a)
    # Drop index `dim` along every axis in turn (rows, then columns, etc.).
    for axis in range(out.ndim):
        out = np.delete(out, dim, axis=axis)
    return out
254
255
def add_dim(a, dim, val=0):
    """Add entries of `val` corresponding to dimension `dim` to an array.

    Note that a `dim` of 0 refers to the “x” values.

    Examples:

    >>> a = [1, 2]
    >>> add_dim(a, 0)
    array([0, 1, 2])
    >>> add_dim(a, 1)
    array([1, 0, 2])
    >>> add_dim(a, 2)
    array([1, 2, 0])

    >>> add_dim([1.0, 2.0], 2)
    array([1., 2., 0.])

    >>> a = [[1, 2], [3, 4]]
    >>> add_dim(a, 0)
    array([[0, 0, 0],
           [0, 1, 2],
           [0, 3, 4]])
    >>> add_dim(a, 1)
    array([[1, 0, 2],
           [0, 0, 0],
           [3, 0, 4]])
    >>> add_dim(a, 2)
    array([[1, 2, 0],
           [3, 4, 0],
           [0, 0, 0]])

    """
    _a = np.asarray(a)
    # Insert `val` (not a hard-coded 0 — previously the `val` argument was
    # silently ignored) at index `dim` along every axis in turn.
    for i, _ in enumerate(_a.shape):
        _a = np.insert(_a, [dim], val, axis=i)
    return _a
293
294
295def default_ncpus() -> int:
296    """Get a safe default number of CPUs available for multiprocessing.
297
298    On Linux platforms that support it, the method `os.sched_getaffinity()` is used.
299    On Mac OS, the command `sysctl -n hw.ncpu` is used.
300    On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried.
301    If any of these fail, a fallback of 1 is used and a warning is logged.
302
303    """
304    try:
305        match platform.system():
306            case "Linux":
307                return len(os.sched_getaffinity(0)) - 1  # May raise AttributeError.
308            case "Darwin":
309                # May raise CalledProcessError.
310                out = subprocess.run(
311                    ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True
312                )
313                return int(out.stdout.strip()) - 1
314            case "Windows":
315                return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1
316            case _:
317                return 1
318    except (AttributeError, subprocess.CalledProcessError, KeyError):
319        return 1
320
321
def diff_like(a):
    """Get forward difference of 2D array `a`, with repeated last elements.

    The repeated last elements ensure that output and input arrays have equal shape.

    Examples:

    >>> diff_like(np.array([1, 2, 3, 4, 5]))
    array([[1, 1, 1, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
    array([[1, 1, 1, 1, 1],
           [2, 3, 3, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
    array([[ 1.,  1.,  1.,  1.,  1.],
           [ 2.,  3.,  3.,  1.,  1.],
           [-1.,  0.,  0., inf, nan]])

    """
    rows = np.atleast_2d(a)
    # Extrapolate one extra column by repeating the final per-row step, so
    # np.diff produces output with the same shape as the input.
    final_step = rows[:, -1] - rows[:, -2]
    extension = (rows[:, -1] + final_step)[:, np.newaxis]
    return np.diff(rows, append=extension)
346
347
def angle_fse_simpleshear(strain):
    """Get angle of FSE long axis anticlockwise from the X axis in simple shear."""
    # Slope of the finite strain ellipse long axis: √(strain² + 1) + strain.
    slope = np.sqrt(strain**2 + 1) + strain
    return np.degrees(np.arctan(slope))
351
352
def lag_2d_corner_flow(θ):
    """Get predicted grain orientation lag for 2D corner flow.

    See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222).

    """
    # NOTE(review): the local identifier was lost in the original listing
    # (every occurrence appeared as an empty token); reconstructed here as a
    # masked copy of θ, which is the only binding consistent with the formula.
    # Mask values of θ below 1e-15 to avoid division by zero in tan(θ);
    # masked entries propagate through the arithmetic below.
    θ_masked = np.ma.masked_less(θ, 1e-15)
    return (θ_masked * (θ_masked**2 + np.cos(θ_masked) ** 2)) / (
        np.tan(θ_masked)
        * (θ_masked**2 + np.cos(θ_masked) ** 2 - θ_masked * np.sin(2 * θ_masked))
    )
363
364
@nb.njit(fastmath=True)
def quat_product(q1, q2):
    """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format."""
    # Vector part: w₁v₂ + w₂v₁ + v₁×v₂; scalar part: w₁w₂ − v₁·v₂.
    # Fixed: the cross product must combine BOTH vector parts — the previous
    # np.cross(q1[:3], q1[:3]) was identically zero, dropping the v₁×v₂ term.
    return [
        *q1[-1] * q2[:3] + q2[-1] * q1[:3] + np.cross(q1[:3], q2[:3]),
        q1[-1] * q2[-1] - np.dot(q1[:3], q2[:3]),
    ]
372
373
def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
    """Redraw legend on matplotlib axis or figure.

    Transparency is removed from legend symbols.
    If `fig` is not None and `remove_all` is True,
    all legends are first removed from the parent figure.
    Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default,
    or `matplotlib.figure.Figure.legend` if `fig` is not None.

    If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes
    instead of taking up space in the original axes. This option requires `fig=None`.

    .. warning::
        Note that if `fig` is not `None`, the legend may be cropped from the saved
        figure due to a Matplotlib bug. In this case, it is required to add the
        arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`,
        where `legend` is the object returned by this function. To prevent the legend
        from consuming axes/subplot space, it is further required to add the lines:
        `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)`
        and `fig.set_layout_engine("none")` before saving the figure.

    """
    handler_map = {
        PathCollection: HandlerPathCollection(
            update_func=_remove_legend_symbol_transparency
        ),
        Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency),
    }
    if fig is None:
        # Fetch handles/labels unconditionally so `legendax` works even when
        # `ax` has no legend yet (previously this raised a NameError because
        # `handles`/`labels` were only bound when an old legend existed).
        handles, labels = ax.get_legend_handles_labels()
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if legendax is not None:
            legendax.axis("off")
            return legendax.legend(handles, labels, handler_map=handler_map, **kwargs)
        return ax.legend(handler_map=handler_map, **kwargs)
    else:
        if legendax is not None:
            _log.warning("ignoring `legendax` argument which requires `fig=None`")
        for legend in fig.legends:
            if legend is not None:
                legend.remove()
        if remove_all:
            for ax in fig.axes:
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()
        return fig.legend(handler_map=handler_map, **kwargs)
423
424
def add_subplot_labels(
    mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs
):
    """Add subplot labels to axes mosaic.

    Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels.
    If `labelmap` is None, the keys in `axs` will be used as the labels by default.

    If `internal` is `False` (default), the axes titles will be used.
    Otherwise, internal labels will be drawn with `ax.text`,
    in which case `loc` must be a tuple of floats.

    Any axes in `axs` corresponding to the special key `legend` are skipped.

    """
    for key, ax in mosaic.items():
        if key.lower() == "legend":
            continue
        label = labelmap[key] if labelmap is not None else key
        if not internal:
            ax.set_title(label, loc=loc, fontsize=fontsize, **kwargs)
            continue
        if isinstance(loc, str):
            raise ValueError(
                "'loc' argument must be a sequence of float when 'internal' is 'True'"
            )
        # Offset the label slightly inside the axes (10pt right, 5pt down).
        offset = ScaledTranslation(10 / 72, -5 / 72, ax.figure.dpi_scale_trans)
        ax.text(
            *loc,
            label,
            transform=ax.transAxes + offset,
            fontsize=fontsize,
            bbox={
                "facecolor": (1.0, 1.0, 1.0, 0.3),
                "edgecolor": "none",
                "pad": 3.0,
            },
        )
463
464
def _remove_legend_symbol_transparency(handle, orig):
    """Remove transparency from symbols used in a Matplotlib legend.

    Used as an `update_func` for the legend handlers in `redraw_legend`: copies
    the properties of the original artist `orig` onto the legend `handle`, then
    forces full opacity on the handle.

    """
    # https://stackoverflow.com/a/59629242/12519962
    handle.update_from(orig)
    handle.set_alpha(1)
def import_proc_pool() -> tuple:
22def import_proc_pool() -> tuple:
23    """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.
24
25    Import a process `Pool` object either from Ray of from Python's stdlib.
26    Both offer the same API, the Ray implementation will be preferred if available.
27    Using the `Pool` provided by Ray allows for distributed memory multiprocessing.
28
29    Returns a tuple containing the `Pool` object and a boolean flag which is `True` if
30    Ray is available.
31
32    """
33    try:
34        from ray.util.multiprocessing import Pool
35
36        has_ray = True
37    except ImportError:
38        from multiprocessing import Pool
39
40        has_ray = False
41    return Pool, has_ray

Import either ray.util.multiprocessing.Pool or multiprocessing.Pool.

Import a process Pool object either from Ray or from Python's stdlib. Both offer the same API; the Ray implementation is preferred when available, because the Pool provided by Ray allows for distributed memory multiprocessing.

Returns a tuple containing the Pool object and a boolean flag which is True if Ray is available.

def in_ci(platform: str) -> bool:
44def in_ci(platform: str) -> bool:
45    """Check if we are in a GitHub runner with the given operating system."""
46    # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
47    return sys.platform == platform and os.getenv("CI") is not None

Check if we are in a GitHub runner with the given operating system.

class SerializedCallable:
50class SerializedCallable:
51    """A serialized version of the callable f.
52
53    Serialization is performed using the dill library. The object is safe to pass into
54    `multiprocessing.Pool.map` and its alternatives.
55
56    .. note:: To serialize a lexical closure (i.e. a function defined inside a
57        function), use the `serializable` decorator.
58
59    """
60
61    def __init__(self, f):
62        self._f = dill.dumps(f, protocol=5, byref=True)
63
64    def __call__(self, *args, **kwargs):
65        return dill.loads(self._f)(*args, **kwargs)

A serialized version of the callable f.

Serialization is performed using the dill library. The object is safe to pass into multiprocessing.Pool.map and its alternatives.

To serialize a lexical closure (i.e. a function defined inside a function), use the serializable decorator.

SerializedCallable(f)
61    def __init__(self, f):
62        self._f = dill.dumps(f, protocol=5, byref=True)
def serializable(f):
68def serializable(f):
69    """Make decorated function serializable.
70
71    .. warning:: The decorated function cannot be a method, and it will loose its
72        docstring. It is not possible to use `functools.wraps` to mitigate this.
73
74    """
75    return SerializedCallable(f)

Make decorated function serializable.

The decorated function cannot be a method, and it will lose its docstring. It is not possible to use functools.wraps to mitigate this.

def defined_if(cond):
78def defined_if(cond):
79    """Only define decorated function if `cond` is `True`."""
80
81    def _defined_if(f):
82        def not_f(*args, **kwargs):
83            # Throw the same as we would get from `type(undefined_symbol)`.
84            raise NameError(f"name '{f.__name__}' is not defined")
85
86        @wraps(f)
87        def wrapper(*args, **kwargs):
88            if cond:
89                return f(*args, **kwargs)
90            return not_f(*args, **kwargs)
91
92        return wrapper
93
94    return _defined_if

Only define decorated function if cond is True.

def halfspace( age, z, surface_temp=273, diff_temp=1350, diffusivity=2.23e-06, fit='Korenaga2016'):
 97def halfspace(
 98    age, z, surface_temp=273, diff_temp=1350, diffusivity=2.23e-6, fit="Korenaga2016"
 99):
100    r"""Get halfspace cooling temperature based on the chosen fit.
101
102    $$T₀ + ΔT ⋅ \mathrm{erf}\left(\frac{z}{2 \sqrt{κ t}}\right) + Q$$
103
104    Temperatures $T₀$ (surface), $ΔT$ (base - surface) and $Q$ (adiabatic correction)
105    are expected to be in Kelvin. The diffusivity $κ$ is expected to be in m²s⁻¹. Depth
106    $z$ is in metres and age $t$ is in seconds. Supported fits are:
107    - ["Korenaga2016"](http://dx.doi.org/10.1002/2016JB013395)¹, which implements $κ(z)$
108    - "Standard", i.e. $Q = 0$
109
110    ¹Although the fit is found in the 2016 paper, the equation is discussed as a
111    reference model in [Korenaga et al. 2021](https://doi.org/10.1029/2020JB021528).
112    The thermal diffusivity below 7km depth is hardcoded to 3.47e-7.
113
114    """
115    match fit:
116        case "Korenaga2016":
117            a1 = 0.602e-3
118            a2 = -6.045e-10
119            adiabatic = a1 * z + a2 * z**2
120            if z < 7:
121                κ = 3.45e-7
122            else:
123                b0 = -1.255
124                b1 = 9.944
125                b2 = -25.0619
126                b3 = 32.2944
127                b4 = -22.2017
128                b5 = 7.7336
129                b6 = -1.0622
130                coeffs = (b0, b1, b2, b3, b4, b5, b6)
131                z_ref = 1e5
132                κ_0 = diffusivity
133                κ = κ_0 * np.sum(
134                    [b * (z / z_ref) ** (n / 2) for n, b in enumerate(coeffs)]
135                )
136        case "Standard":
137            κ = diffusivity
138            adiabatic = 0.0
139        case _:
140            raise ValueError(f"unsupported fit '{fit}'")
141    return surface_temp + diff_temp * sp.erf(z / (2 * np.sqrt(κ * age))) + adiabatic

Get halfspace cooling temperature based on the chosen fit.

$$T₀ + ΔT ⋅ \mathrm{erf}\left(\frac{z}{2 \sqrt{κ t}}\right) + Q$$

Temperatures $T₀$ (surface), $ΔT$ (base - surface) and $Q$ (adiabatic correction) are expected to be in Kelvin. The diffusivity $κ$ is expected to be in m²s⁻¹. Depth $z$ is in metres and age $t$ is in seconds. Supported fits are:

¹Although the fit is found in the 2016 paper, the equation is discussed as a reference model in Korenaga et al. 2021. The thermal diffusivity below 7km depth is hardcoded to 3.47e-7.

@nb.njit(fastmath=True)
def strain_increment(dt, velocity_gradient):
144@nb.njit(fastmath=True)
145def strain_increment(dt, velocity_gradient):
146    """Calculate strain increment for a given time increment and velocity gradient.
147
148    Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the
149    “(engineering) shear strain” increment.
150
151    """
152    return (
153        np.abs(dt)
154        * np.abs(
155            np.linalg.eigvalsh((velocity_gradient + velocity_gradient.transpose()) / 2)
156        ).max()
157    )

Calculate strain increment for a given time increment and velocity gradient.

Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the “(engineering) shear strain” increment.

@nb.njit
def apply_gbs( orientations, fractions, gbs_threshold, orientations_prev, n_grains) -> tuple[numpy.ndarray, numpy.ndarray]:
160@nb.njit
161def apply_gbs(
162    orientations, fractions, gbs_threshold, orientations_prev, n_grains
163) -> tuple[np.ndarray, np.ndarray]:
164    """Apply grain boundary sliding for small grains."""
165    mask = fractions < (gbs_threshold / n_grains)
166    # _log.debug(
167    #     "grain boundary sliding activity (volume percentage): %s",
168    #     len(np.nonzero(mask)) / len(fractions),
169    # )
170    # No rotation: carry over previous orientations.
171    orientations[mask, :, :] = orientations_prev[mask, :, :]
172    fractions[mask] = gbs_threshold / n_grains
173    fractions /= fractions.sum()
174    # _log.debug(
175    #     "grain volume fractions: median=%e, min=%e, max=%e, sum=%e",
176    #     np.median(fractions),
177    #     np.min(fractions),
178    #     np.max(fractions),
179    #     np.sum(fractions),
180    # )
181    return orientations, fractions

Apply grain boundary sliding for small grains.

@nb.njit
def extract_vars(y, n_grains) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
184@nb.njit
185def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
186    """Extract deformation gradient, orientation matrices and grain sizes from y."""
187    deformation_gradient = y[:9].reshape((3, 3))
188    orientations = y[9 : n_grains * 9 + 9].reshape((n_grains, 3, 3)).clip(-1, 1)
189    fractions = y[n_grains * 9 + 9 : n_grains * 10 + 9].clip(0, None)
190    fractions /= fractions.sum()
191    return deformation_gradient, orientations, fractions

Extract deformation gradient, orientation matrices and grain sizes from y.

def pad_with(a, x=nan):
194def pad_with(a, x=np.nan):
195    """Pad a list of arrays with `x` and return as a new 2D array with regular shape.
196
197    >>> pad_with([[1, 2, 3], [4, 5], [6]])
198    array([[ 1.,  2.,  3.],
199           [ 4.,  5., nan],
200           [ 6., nan, nan]])
201    >>> pad_with([[1, 2, 3], [4, 5], [6]], x=0)
202    array([[1, 2, 3],
203           [4, 5, 0],
204           [6, 0, 0]])
205    >>> pad_with([[1, 2, 3]])
206    array([[1., 2., 3.]])
207    >>> pad_with([[1, 2, 3]], x=0)
208    array([[1, 2, 3]])
209
210    """
211    longest = max([len(d) for d in a])
212    out = np.full((len(a), longest), x)
213    for i, d in enumerate(a):
214        out[i, : len(d)] = d
215    return out

Pad a list of arrays with x and return as a new 2D array with regular shape.

>>> pad_with([[1, 2, 3], [4, 5], [6]])
array([[ 1.,  2.,  3.],
       [ 4.,  5., nan],
       [ 6., nan, nan]])
>>> pad_with([[1, 2, 3], [4, 5], [6]], x=0)
array([[1, 2, 3],
       [4, 5, 0],
       [6, 0, 0]])
>>> pad_with([[1, 2, 3]])
array([[1., 2., 3.]])
>>> pad_with([[1, 2, 3]], x=0)
array([[1, 2, 3]])
def remove_nans(a):
218def remove_nans(a):
219    """Remove NaN values from array."""
220    a = np.asarray(a)
221    return a[~np.isnan(a)]

Remove NaN values from array.

def remove_dim(a, dim):
224def remove_dim(a, dim):
225    """Remove all values corresponding to dimension `dim` from an array.
226
227    Note that a `dim` of 0 refers to the “x” values.
228
229    Examples:
230
231    >>> a = [1, 2, 3]
232    >>> remove_dim(a, 0)
233    array([2, 3])
234    >>> remove_dim(a, 1)
235    array([1, 3])
236    >>> remove_dim(a, 2)
237    array([1, 2])
238
239    >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
240    >>> remove_dim(a, 0)
241    array([[5, 6],
242           [8, 9]])
243    >>> remove_dim(a, 1)
244    array([[1, 3],
245           [7, 9]])
246    >>> remove_dim(a, 2)
247    array([[1, 2],
248           [4, 5]])
249
250    """
251    _a = np.asarray(a)
252    for i, _ in enumerate(_a.shape):
253        _a = np.delete(_a, [dim], axis=i)
254    return _a

Remove all values corresponding to dimension dim from an array.

Note that a dim of 0 refers to the “x” values.

Examples:

>>> a = [1, 2, 3]
>>> remove_dim(a, 0)
array([2, 3])
>>> remove_dim(a, 1)
array([1, 3])
>>> remove_dim(a, 2)
array([1, 2])
>>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
>>> remove_dim(a, 0)
array([[5, 6],
       [8, 9]])
>>> remove_dim(a, 1)
array([[1, 3],
       [7, 9]])
>>> remove_dim(a, 2)
array([[1, 2],
       [4, 5]])
def add_dim(a, dim, val=0):
257def add_dim(a, dim, val=0):
258    """Add entries of `val` corresponding to dimension `dim` to an array.
259
260    Note that a `dim` of 0 refers to the “x” values.
261
262    Examples:
263
264    >>> a = [1, 2]
265    >>> add_dim(a, 0)
266    array([0, 1, 2])
267    >>> add_dim(a, 1)
268    array([1, 0, 2])
269    >>> add_dim(a, 2)
270    array([1, 2, 0])
271
272    >>> add_dim([1.0, 2.0], 2)
273    array([1., 2., 0.])
274
275    >>> a = [[1, 2], [3, 4]]
276    >>> add_dim(a, 0)
277    array([[0, 0, 0],
278           [0, 1, 2],
279           [0, 3, 4]])
280    >>> add_dim(a, 1)
281    array([[1, 0, 2],
282           [0, 0, 0],
283           [3, 0, 4]])
284    >>> add_dim(a, 2)
285    array([[1, 2, 0],
286           [3, 4, 0],
287           [0, 0, 0]])
288
289    """
290    _a = np.asarray(a)
291    for i, _ in enumerate(_a.shape):
292        _a = np.insert(_a, [dim], 0, axis=i)
293    return _a

Add entries of val corresponding to dimension dim to an array.

Note that a dim of 0 refers to the “x” values.

Examples:

>>> a = [1, 2]
>>> add_dim(a, 0)
array([0, 1, 2])
>>> add_dim(a, 1)
array([1, 0, 2])
>>> add_dim(a, 2)
array([1, 2, 0])
>>> add_dim([1.0, 2.0], 2)
array([1., 2., 0.])
>>> a = [[1, 2], [3, 4]]
>>> add_dim(a, 0)
array([[0, 0, 0],
       [0, 1, 2],
       [0, 3, 4]])
>>> add_dim(a, 1)
array([[1, 0, 2],
       [0, 0, 0],
       [3, 0, 4]])
>>> add_dim(a, 2)
array([[1, 2, 0],
       [3, 4, 0],
       [0, 0, 0]])
def default_ncpus() -> int:
296def default_ncpus() -> int:
297    """Get a safe default number of CPUs available for multiprocessing.
298
299    On Linux platforms that support it, the method `os.sched_getaffinity()` is used.
300    On Mac OS, the command `sysctl -n hw.ncpu` is used.
301    On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried.
302    If any of these fail, a fallback of 1 is used and a warning is logged.
303
304    """
305    try:
306        match platform.system():
307            case "Linux":
308                return len(os.sched_getaffinity(0)) - 1  # May raise AttributeError.
309            case "Darwin":
310                # May raise CalledProcessError.
311                out = subprocess.run(
312                    ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True
313                )
314                return int(out.stdout.strip()) - 1
315            case "Windows":
316                return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1
317            case _:
318                return 1
319    except (AttributeError, subprocess.CalledProcessError, KeyError):
320        return 1

Get a safe default number of CPUs available for multiprocessing.

On Linux platforms that support it, the method os.sched_getaffinity() is used. On Mac OS, the command sysctl -n hw.ncpu is used. On Windows, the environment variable NUMBER_OF_PROCESSORS is queried. If any of these fail, a fallback of 1 is used and a warning is logged.

def diff_like(a):
323def diff_like(a):
324    """Get forward difference of 2D array `a`, with repeated last elements.
325
326    The repeated last elements ensure that output and input arrays have equal shape.
327
328    Examples:
329
330    >>> diff_like(np.array([1, 2, 3, 4, 5]))
331    array([[1, 1, 1, 1, 1]])
332
333    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
334    array([[1, 1, 1, 1, 1],
335           [2, 3, 3, 1, 1]])
336
337    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
338    array([[ 1.,  1.,  1.,  1.,  1.],
339           [ 2.,  3.,  3.,  1.,  1.],
340           [-1.,  0.,  0., inf, nan]])
341
342    """
343    a2 = np.atleast_2d(a)
344    return np.diff(
345        a2, append=np.reshape(a2[:, -1] + (a2[:, -1] - a2[:, -2]), (a2.shape[0], 1))
346    )

Get forward difference of 2D array a, with repeated last elements.

The repeated last elements ensure that output and input arrays have equal shape.

Examples:

>>> diff_like(np.array([1, 2, 3, 4, 5]))
array([[1, 1, 1, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
array([[1, 1, 1, 1, 1],
       [2, 3, 3, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
array([[ 1.,  1.,  1.,  1.,  1.],
       [ 2.,  3.,  3.,  1.,  1.],
       [-1.,  0.,  0., inf, nan]])
def angle_fse_simpleshear(strain):
    """Get angle of FSE long axis anticlockwise from the X axis in simple shear."""
    # Slope of the finite strain ellipse long axis: strain + sqrt(strain² + 1).
    long_axis_slope = np.sqrt(strain**2 + 1) + strain
    return np.rad2deg(np.arctan(long_axis_slope))

Get angle of FSE long axis anticlockwise from the X axis in simple shear.

def lag_2d_corner_flow(θ):
    """Get predicted grain orientation lag for 2D corner flow.

    See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222).

    """
    # The local identifier was lost to text extraction in the original
    # (` = np.ma.masked_less(...)`); reconstructed here as θ_masked.
    # Mask values of θ at (or numerically indistinguishable from) zero to
    # avoid division by zero from tan(θ) in the denominator.
    θ_masked = np.ma.masked_less(θ, 1e-15)
    return (θ_masked * (θ_masked**2 + np.cos(θ_masked) ** 2)) / (
        np.tan(θ_masked)
        * (θ_masked**2 + np.cos(θ_masked) ** 2 - θ_masked * np.sin(2 * θ_masked))
    )

Get predicted grain orientation lag for 2D corner flow.

See eq. 11 in Kaminski & Ribe (2002).

@nb.njit(fastmath=True)
def quat_product(q1, q2):
    """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.

    Implements the Hamilton product: the vector part is
    w1·v2 + w2·v1 + v1×v2 and the scalar part is w1·w2 − v1·v2.

    """
    return [
        # BUG FIX: the cross product must mix both vector parts (v1 × v2);
        # the previous np.cross(q1[:3], q1[:3]) is identically zero.
        *q1[-1] * q2[:3] + q2[-1] * q1[:3] + np.cross(q1[:3], q2[:3]),
        q1[-1] * q2[-1] - np.dot(q1[:3], q2[:3]),
    ]

Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.

def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
    """Redraw legend on matplotlib axis or figure.

    Transparency is removed from legend symbols.
    If `fig` is not None and `remove_all` is True,
    all legends are first removed from the parent figure.
    Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default,
    or `matplotlib.figure.Figure.legend` if `fig` is not None.

    If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes
    instead of taking up space in the original axes. This option requires `fig=None`.

    .. warning::
        Note that if `fig` is not `None`, the legend may be cropped from the saved
        figure due to a Matplotlib bug. In this case, it is required to add the
        arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`,
        where `legend` is the object returned by this function. To prevent the legend
        from consuming axes/subplot space, it is further required to add the lines:
        `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)`
        and `fig.set_layout_engine("none")` before saving the figure.

    """
    handler_map = {
        PathCollection: HandlerPathCollection(
            update_func=_remove_legend_symbol_transparency
        ),
        Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency),
    }
    if fig is None:
        # Collect handles/labels unconditionally, BEFORE removing any existing
        # legend. Previously they were only bound when a legend already
        # existed, so the `legendax` branch raised UnboundLocalError on axes
        # without a legend.
        handles, labels = ax.get_legend_handles_labels()
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if legendax is not None:
            legendax.axis("off")
            return legendax.legend(handles, labels, handler_map=handler_map, **kwargs)
        return ax.legend(handler_map=handler_map, **kwargs)
    else:
        if legendax is not None:
            _log.warning("ignoring `legendax` argument which requires `fig=None`")
        # Iterate over a copy: `Legend.remove` mutates `fig.legends`, and
        # removing items from a list while iterating it skips elements.
        for legend in list(fig.legends):
            if legend is not None:
                legend.remove()
        if remove_all:
            for ax in fig.axes:
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()
        return fig.legend(handler_map=handler_map, **kwargs)

Redraw legend on matplotlib axis or figure.

Transparency is removed from legend symbols. If fig is not None and remove_all is True, all legends are first removed from the parent figure. Optional keyword arguments are passed to matplotlib.axes.Axes.legend by default, or matplotlib.figure.Figure.legend if fig is not None.

If legendax is not None, the axis legend will be redrawn using the legendax axes instead of taking up space in the original axes. This option requires fig=None.

Note that if fig is not None, the legend may be cropped from the saved figure due to a Matplotlib bug. In this case, it is required to add the arguments bbox_extra_artists=(legend,) and bbox_inches="tight" to savefig, where legend is the object returned by this function. To prevent the legend from consuming axes/subplot space, it is further required to add the lines: legend.set_in_layout(False), fig.canvas.draw(), legend.set_in_layout(True) and fig.set_layout_engine("none") before saving the figure.

def add_subplot_labels(
    mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs
):
    """Add subplot labels to axes mosaic.

    Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels.
    If `labelmap` is None, the keys in `axs` will be used as the labels by default.

    If `internal` is `False` (default), the axes titles will be used.
    Otherwise, internal labels will be drawn with `ax.text`,
    in which case `loc` must be a tuple of floats.

    Any axes in `axs` corresponding to the special key `legend` are skipped.

    """
    for key, axes in mosaic.items():
        # The special "legend" entry is reserved for a legend axes; skip it.
        if key.lower() == "legend":
            continue
        label = key if labelmap is None else labelmap[key]
        if not internal:
            axes.set_title(label, loc=loc, fontsize=fontsize, **kwargs)
            continue
        # Offset internal labels by (10, -5) points from the axes corner.
        offset = ScaledTranslation(10 / 72, -5 / 72, axes.figure.dpi_scale_trans)
        if isinstance(loc, str):
            raise ValueError(
                "'loc' argument must be a sequence of float when 'internal' is 'True'"
            )
        axes.text(
            *loc,
            label,
            transform=axes.transAxes + offset,
            fontsize=fontsize,
            bbox={
                "facecolor": (1.0, 1.0, 1.0, 0.3),
                "edgecolor": "none",
                "pad": 3.0,
            },
        )

Add subplot labels to axes mosaic.

Use labelmap to specify a dictionary that maps keys in mosaic to subplot labels. If labelmap is None, the keys in axs will be used as the labels by default.

If internal is False (default), the axes titles will be used. Otherwise, internal labels will be drawn with ax.text, in which case loc must be a tuple of floats.

Any axes in axs corresponding to the special key legend are skipped.