
PyDRex: Miscellaneous utility methods.

  1"""> PyDRex: Miscellaneous utility methods."""
  3import os
  4import platform
  5import subprocess
  6import sys
  7from functools import wraps
  9import dill
 10import numba as nb
 11import numpy as np
 12import scipy.special as sp
 13from matplotlib.collections import PathCollection
 14from matplotlib.legend_handler import HandlerLine2D, HandlerPathCollection
 15from matplotlib.pyplot import Line2D
 16from matplotlib.transforms import ScaledTranslation
 18from pydrex import logger as _log
 21def import_proc_pool() -> tuple:
 22    """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.
 24    Import a process `Pool` object either from Ray of from Python's stdlib.
 25    Both offer the same API, the Ray implementation will be preferred if available.
 26    Using the `Pool` provided by Ray allows for distributed memory multiprocessing.
 28    Returns a tuple containing the `Pool` object and a boolean flag which is `True` if
 29    Ray is available.
 31    """
 32    try:
 33        from ray.util.multiprocessing import Pool
 35        has_ray = True
 36    except ImportError:
 37        from multiprocessing import Pool
 39        has_ray = False
 40    return Pool, has_ray
 43def in_ci(platform: str) -> bool:
 44    """Check if we are in a GitHub runner with the given operating system."""
 45    #
 46    return sys.platform == platform and os.getenv("CI") is not None
 49class SerializedCallable:
 50    """A serialized version of the callable f.
 52    Serialization is performed using the dill library. The object is safe to pass into
 53    `` and its alternatives.
 55    .. note:: To serialize a lexical closure (i.e. a function defined inside a
 56        function), use the `serializable` decorator.
 58    """
 60    def __init__(self, f):
 61        self._f = dill.dumps(f, protocol=5, byref=True)
 63    def __call__(self, *args, **kwargs):
 64        return dill.loads(self._f)(*args, **kwargs)
 67def serializable(f):
 68    """Make decorated function serializable.
 70    .. warning:: The decorated function cannot be a method, and it will loose its
 71        docstring. It is not possible to use `functools.wraps` to mitigate this.
 73    """
 74    return SerializedCallable(f)
 77def defined_if(cond):
 78    """Only define decorated function if `cond` is `True`."""
 80    def _defined_if(f):
 81        def not_f(*args, **kwargs):
 82            # Throw the same as we would get from `type(undefined_symbol)`.
 83            raise NameError(f"name '{f.__name__}' is not defined")
 85        @wraps(f)
 86        def wrapper(*args, **kwargs):
 87            if cond:
 88                return f(*args, **kwargs)
 89            return not_f(*args, **kwargs)
 91        return wrapper
 93    return _defined_if
 96def halfspace(
 97    age, z, surface_temp=273, diff_temp=1350, diffusivity=2.23e-6, fit="Korenaga2016"
 99    r"""Get halfspace cooling temperature based on the chosen fit.
101    $$T₀ + ΔT ⋅ \mathrm{erf}\left(\frac{z}{2 \sqrt{κ t}}\right) + Q$$
103    Temperatures $T₀$ (surface), $ΔT$ (base - surface) and $Q$ (adiabatic correction)
104    are expected to be in Kelvin. The diffusivity $κ$ is expected to be in m²s⁻¹. Depth
105    $z$ is in metres and age $t$ is in seconds. Supported fits are:
106    - ["Korenaga2016"](¹, which implements $κ(z)$
107    - "Standard", i.e. $Q = 0$
109    ¹Although the fit is found in the 2016 paper, the equation is discussed as a
110    reference model in [Korenaga et al. 2021](
111    The thermal diffusivity below 7km depth is hardcoded to 3.47e-7.
113    """
114    match fit:
115        case "Korenaga2016":
116            a1 = 0.602e-3
117            a2 = -6.045e-10
118            adiabatic = a1 * z + a2 * z**2
119            if z < 7:
120                κ = 3.45e-7
121            else:
122                b0 = -1.255
123                b1 = 9.944
124                b2 = -25.0619
125                b3 = 32.2944
126                b4 = -22.2017
127                b5 = 7.7336
128                b6 = -1.0622
129                coeffs = (b0, b1, b2, b3, b4, b5, b6)
130                z_ref = 1e5
131                κ_0 = diffusivity
132                κ = κ_0 * np.sum(
133                    [b * (z / z_ref) ** (n / 2) for n, b in enumerate(coeffs)]
134                )
135        case "Standard":
136            κ = diffusivity
137            adiabatic = 0.0
138        case _:
139            raise ValueError(f"unsupported fit '{fit}'")
140    return surface_temp + diff_temp * sp.erf(z / (2 * np.sqrt(κ * age))) + adiabatic
def strain_increment(dt, velocity_gradient):
    """Return the strain accumulated over a time step `dt` under `velocity_gradient`.

    The result is |dt| times the largest absolute eigenvalue of the symmetric part of
    the velocity gradient (the strain rate tensor). This is the "tensorial" strain
    increment ε, equal to γ/2, where γ is the "(engineering) shear strain" increment.

    """
    strain_rate = (velocity_gradient + velocity_gradient.transpose()) / 2
    max_abs_eigenvalue = np.abs(np.linalg.eigvalsh(strain_rate)).max()
    return np.abs(dt) * max_abs_eigenvalue
def apply_gbs(
    orientations, fractions, gbs_threshold, orientations_prev, n_grains
) -> tuple[np.ndarray, np.ndarray]:
    """Apply grain boundary sliding for small grains.

    Grains whose volume fraction is below `gbs_threshold / n_grains` keep their
    orientation from `orientations_prev`, and their fraction is raised to that
    threshold. All fractions are then renormalised to sum to 1.

    `orientations` and `fractions` are modified in place and also returned.

    """
    min_fraction = gbs_threshold / n_grains
    is_small = fractions < min_fraction
    # Small grains do not rotate: carry over their previous orientations.
    orientations[is_small, :, :] = orientations_prev[is_small, :, :]
    fractions[is_small] = min_fraction
    fractions /= fractions.sum()
    return orientations, fractions
def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split the flat state vector `y` into its components.

    `y` is laid out as: 9 entries of the deformation gradient (reshaped to 3×3), then
    `9 * n_grains` orientation matrix entries (reshaped to `(n_grains, 3, 3)` and
    clipped to [-1, 1]), then `n_grains` volume fractions (clipped to be non-negative
    and normalised to sum to 1).

    """
    orientations_end = 9 + 9 * n_grains
    fractions_end = orientations_end + n_grains
    deformation_gradient = y[:9].reshape((3, 3))
    orientations = y[9:orientations_end].reshape((n_grains, 3, 3)).clip(-1, 1)
    fractions = y[orientations_end:fractions_end].clip(0, None)
    fractions /= fractions.sum()
    return deformation_gradient, orientations, fractions
def pad_with(a, x=np.nan):
    """Pad a list of arrays with `x` and return as a new 2D array with regular shape.

    The output dtype is the common type of `x` and the (non-empty) input sequences,
    so padding float data with an integer `x` does not truncate the data. An empty
    `a` gives an array of shape `(0, 0)`.

    >>> pad_with([[1, 2, 3], [4, 5], [6]])
    array([[ 1.,  2.,  3.],
           [ 4.,  5., nan],
           [ 6., nan, nan]])
    >>> pad_with([[1, 2, 3], [4, 5], [6]], x=0)
    array([[1, 2, 3],
           [4, 5, 0],
           [6, 0, 0]])
    >>> pad_with([[1, 2, 3]])
    array([[1., 2., 3.]])
    >>> pad_with([[1, 2, 3]], x=0)
    array([[1, 2, 3]])
    >>> pad_with([[1.5, 2.5], [3.5]], x=0)
    array([[1.5, 2.5],
           [3.5, 0. ]])

    """
    rows = [np.asarray(d) for d in a]
    longest = max((len(r) for r in rows), default=0)
    # Widen the dtype implied by `x` just enough to hold the data. Empty rows are
    # skipped because `np.asarray([])` is float64 and would upcast integer data.
    dtype = np.result_type(np.asarray(x).dtype, *{r.dtype for r in rows if len(r)})
    out = np.full((len(rows), longest), x, dtype=dtype)
    for i, r in enumerate(rows):
        out[i, : len(r)] = r
    return out
def remove_nans(a):
    """Return the non-NaN entries of `a` as a new array.

    `a` is first converted with `np.asarray`. Boolean-mask indexing is used, so an
    N-dimensional input yields a flattened 1D result.

    """
    arr = np.asarray(a)
    keep = np.logical_not(np.isnan(arr))
    return arr[keep]
def remove_dim(a, dim):
    """Remove all values corresponding to dimension `dim` from an array.

    Index `dim` is deleted along every axis of `a`, so a 3-vector becomes a 2-vector
    and a 3×3 matrix becomes a 2×2 matrix. Note that a `dim` of 0 refers to the "x"
    values.

    Examples:

    >>> a = [1, 2, 3]
    >>> remove_dim(a, 0)
    array([2, 3])
    >>> remove_dim(a, 1)
    array([1, 3])
    >>> remove_dim(a, 2)
    array([1, 2])

    >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    >>> remove_dim(a, 0)
    array([[5, 6],
           [8, 9]])
    >>> remove_dim(a, 1)
    array([[1, 3],
           [7, 9]])
    >>> remove_dim(a, 2)
    array([[1, 2],
           [4, 5]])

    """
    result = np.asarray(a)
    for axis in range(result.ndim):
        result = np.delete(result, [dim], axis=axis)
    return result
def add_dim(a, dim, val=0):
    """Add entries of `val` corresponding to dimension `dim` to an array.

    Index `dim` is inserted along every axis of `a`, so a 2-vector becomes a 3-vector
    and a 2×2 matrix becomes a 3×3 matrix. `np.insert` converts `val` to the dtype of
    `a`. Note that a `dim` of 0 refers to the "x" values.

    Examples:

    >>> a = [1, 2]
    >>> add_dim(a, 0)
    array([0, 1, 2])
    >>> add_dim(a, 1)
    array([1, 0, 2])
    >>> add_dim(a, 2)
    array([1, 2, 0])
    >>> add_dim(a, 1, val=9)
    array([1, 9, 2])

    >>> add_dim([1.0, 2.0], 2)
    array([1., 2., 0.])

    >>> a = [[1, 2], [3, 4]]
    >>> add_dim(a, 0)
    array([[0, 0, 0],
           [0, 1, 2],
           [0, 3, 4]])
    >>> add_dim(a, 1)
    array([[1, 0, 2],
           [0, 0, 0],
           [3, 0, 4]])
    >>> add_dim(a, 2)
    array([[1, 2, 0],
           [3, 4, 0],
           [0, 0, 0]])

    """
    result = np.asarray(a)
    for axis in range(result.ndim):
        # Previously a literal 0 was inserted here, silently ignoring `val`.
        result = np.insert(result, [dim], val, axis=axis)
    return result
295def default_ncpus() -> int:
296    """Get a safe default number of CPUs available for multiprocessing.
298    On Linux platforms that support it, the method `os.sched_getaffinity()` is used.
299    On Mac OS, the command `sysctl -n hw.ncpu` is used.
300    On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried.
301    If any of these fail, a fallback of 1 is used and a warning is logged.
303    """
304    try:
305        match platform.system():
306            case "Linux":
307                return len(os.sched_getaffinity(0)) - 1  # May raise AttributeError.
308            case "Darwin":
309                # May raise CalledProcessError.
310                out =
311                    ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True
312                )
313                return int(out.stdout.strip()) - 1
314            case "Windows":
315                return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1
316            case _:
317                return 1
318    except (AttributeError, subprocess.CalledProcessError, KeyError):
319        return 1
def diff_like(a):
    """Get forward differences along the last axis of `a`, keeping its shape.

    `a` is promoted to at least 2D, so a 1D input yields a single-row 2D output. One
    extra column is extrapolated by repeating each row's final step. As a result, the
    last column of the output repeats the previous difference (or is NaN where the
    extrapolation involves `inf - inf`). At least two columns are required.

    Examples:

    >>> diff_like(np.array([1, 2, 3, 4, 5]))
    array([[1, 1, 1, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
    array([[1, 1, 1, 1, 1],
           [2, 3, 3, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
    array([[ 1.,  1.,  1.,  1.,  1.],
           [ 2.,  3.,  3.,  1.,  1.],
           [-1.,  0.,  0., inf, nan]])

    """
    rows = np.atleast_2d(a)
    last = rows[:, -1]
    extra_column = (last + (last - rows[:, -2]))[:, np.newaxis]
    return np.diff(rows, append=extra_column)
def angle_fse_simpleshear(strain):
    """Get angle of FSE long axis anticlockwise from the X axis in simple shear.

    The result is in degrees: 45° at zero strain, tending to 90° as `strain` grows.
    The form arctan(√(ε² + 1) + ε) is consistent with `strain` being the tensorial
    strain ε = γ/2 (as returned by `strain_increment`). NOTE(review): confirm against
    callers.

    """
    tangent = np.sqrt(strain**2 + 1) + strain
    return np.rad2deg(np.arctan(tangent))
def lag_2d_corner_flow(θ):
    """Get predicted grain orientation lag for 2D corner flow.

    See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222).

    Angles `θ` (in radians) below 1e-15 are masked, because the expression is 0/0 at
    θ = 0 (its limit there is 1). The result is therefore a `numpy.ma.MaskedArray`.

    """
    # Mask (near-)zero angles, where tan(θ) in the denominator vanishes.
    _θ = np.ma.masked_less(θ, 1e-15)
    cos_sq = np.cos(_θ) ** 2
    return (_θ * (_θ**2 + cos_sq)) / (
        np.tan(_θ) * (_θ**2 + cos_sq - _θ * np.sin(2 * _θ))
    )
def quat_product(q1, q2):
    """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.

    Computes the Hamilton product q1 ⊗ q2 = (w₁v₂ + w₂v₁ + v₁ × v₂, w₁w₂ − v₁·v₂),
    where vᵢ is the vector part and wᵢ the scalar part. Inputs may be any length-4
    sequences. The result is returned as a list of 4 values.

    """
    q1 = np.asarray(q1)
    q2 = np.asarray(q2)
    v1, w1 = q1[:3], q1[-1]
    v2, w2 = q2[:3], q2[-1]
    # The cross product must mix both operands (v1 × v2); v1 × v1 is always zero.
    return [*(w1 * v2 + w2 * v1 + np.cross(v1, v2)), w1 * w2 - np.dot(v1, v2)]
def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
    """Redraw legend on matplotlib axis or figure.

    Transparency is removed from legend symbols.
    If `fig` is not None and `remove_all` is True,
    all legends are first removed from the parent figure.
    Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default,
    or `matplotlib.figure.Figure.legend` if `fig` is not None.

    If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes
    instead of taking up space in the original axes. This option requires `fig=None`.

    Returns the newly created legend.

    .. warning::
        Note that if `fig` is not `None`, the legend may be cropped from the saved
        figure due to a Matplotlib bug. In this case, it is required to add the
        arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`,
        where `legend` is the object returned by this function. To prevent the legend
        from consuming axes/subplot space, it is further required to add the lines:
        `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)`
        and `fig.set_layout_engine("none")` before saving the figure.

    """
    handler_map = {
        PathCollection: HandlerPathCollection(
            update_func=_remove_legend_symbol_transparency
        ),
        Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency),
    }
    if fig is None:
        # Collect entries unconditionally: previously they were only gathered when
        # `ax` already had a legend, leaving them unbound for `legendax`.
        handles, labels = ax.get_legend_handles_labels()
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if legendax is not None:
            legendax.axis("off")
            return legendax.legend(handles, labels, handler_map=handler_map, **kwargs)
        return ax.legend(handler_map=handler_map, **kwargs)

    if legendax is not None:
        _log.warning("ignoring `legendax` argument which requires `fig=None`")
    # Iterate over a copy: removing a figure legend deletes it from `fig.legends`,
    # which would otherwise skip every other entry.
    for legend in list(fig.legends):
        legend.remove()
    if remove_all:
        for other_ax in fig.axes:
            legend = other_ax.get_legend()
            if legend is not None:
                legend.remove()
    return fig.legend(handler_map=handler_map, **kwargs)
def add_subplot_labels(
    mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs
):
    """Add subplot labels to axes mosaic.

    `mosaic` is a dictionary mapping keys to axes (e.g. as produced by Matplotlib's
    `subplot_mosaic`). Use `labelmap` to specify a dictionary that maps keys in
    `mosaic` to subplot labels. If `labelmap` is None, the keys in `mosaic` will be
    used as the labels by default.

    If `internal` is `False` (default), the axes titles will be used, and extra keyword
    arguments are passed to `matplotlib.axes.Axes.set_title`.
    Otherwise, internal labels will be drawn with `ax.text` (extra keyword arguments are
    ignored), in which case `loc` must be a tuple of floats in axes coordinates.

    Any axes in `mosaic` whose key is `legend` (case-insensitive) are skipped.

    Raises `ValueError` if `internal` is `True` and `loc` is a string.

    """
    if internal and isinstance(loc, str):
        raise ValueError(
            "'loc' argument must be a sequence of float when 'internal' is 'True'"
        )
    for key, ax in mosaic.items():
        if key.lower() == "legend":
            continue
        label = key if labelmap is None else labelmap[key]
        if not internal:
            ax.set_title(label, loc=loc, fontsize=fontsize, **kwargs)
            continue
        # Shift the label 10pt right and 5pt down from `loc` (1pt = 1/72 inch).
        offset = ScaledTranslation(10 / 72, -5 / 72, ax.figure.dpi_scale_trans)
        ax.text(
            *loc,
            label,
            transform=ax.transAxes + offset,
            fontsize=fontsize,
            bbox={
                "facecolor": (1.0, 1.0, 1.0, 0.3),
                "edgecolor": "none",
                "pad": 3.0,
            },
        )
465def _remove_legend_symbol_transparency(handle, orig):
466    """Remove transparency from symbols used in a Matplotlib legend."""
467    #
468    handle.update_from(orig)
469    handle.set_alpha(1)
