pydrex.utils

PyDRex: Miscellaneous utility methods.

  1"""> PyDRex: Miscellaneous utility methods."""
  2
  3import os
  4import platform
  5import subprocess
  6import sys
  7from functools import wraps
  8
  9import dill
 10import numba as nb
 11import numpy as np
 12from matplotlib.collections import PathCollection
 13from matplotlib.legend_handler import HandlerLine2D, HandlerPathCollection
 14from matplotlib.pyplot import Line2D
 15from matplotlib.transforms import ScaledTranslation
 16
 17from pydrex import logger as _log
 18
 19
 20def import_proc_pool() -> tuple:
 21    """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.
 22
 23    Import a process `Pool` object either from Ray of from Python's stdlib.
 24    Both offer the same API, the Ray implementation will be preferred if available.
 25    Using the `Pool` provided by Ray allows for distributed memory multiprocessing.
 26
 27    Returns a tuple containing the `Pool` object and a boolean flag which is `True` if
 28    Ray is available.
 29
 30    """
 31    try:
 32        from ray.util.multiprocessing import Pool
 33
 34        has_ray = True
 35    except ImportError:
 36        from multiprocessing import Pool
 37
 38        has_ray = False
 39    return Pool, has_ray
 40
 41
 42def in_ci(platform: str) -> bool:
 43    """Check if we are in a GitHub runner with the given operating system."""
 44    # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
 45    return sys.platform == platform and os.getenv("CI") is not None
 46
 47
 48class SerializedCallable:
 49    """A serialized version of the callable f.
 50
 51    Serialization is performed using the dill library. The object is safe to pass into
 52    `multiprocessing.Pool.map` and its alternatives.
 53
 54    .. note:: To serialize a lexical closure (i.e. a function defined inside a
 55        function), use the `serializable` decorator.
 56
 57    """
 58
 59    def __init__(self, f):
 60        self._f = dill.dumps(f, protocol=5, byref=True)
 61
 62    def __call__(self, *args, **kwargs):
 63        return dill.loads(self._f)(*args, **kwargs)
 64
 65
 66def serializable(f):
 67    """Make decorated function serializable.
 68
 69    .. warning:: The decorated function cannot be a method, and it will loose its
 70        docstring. It is not possible to use `functools.wraps` to mitigate this.
 71
 72    """
 73    return SerializedCallable(f)
 74
 75
 76def defined_if(cond):
 77    """Only define decorated function if `cond` is `True`."""
 78
 79    def _defined_if(f):
 80        def not_f(*args, **kwargs):
 81            # Throw the same as we would get from `type(undefined_symbol)`.
 82            raise NameError(f"name '{f.__name__}' is not defined")
 83
 84        @wraps(f)
 85        def wrapper(*args, **kwargs):
 86            if cond:
 87                return f(*args, **kwargs)
 88            return not_f(*args, **kwargs)
 89
 90        return wrapper
 91
 92    return _defined_if
 93
 94
@nb.njit(fastmath=True)
def strain_increment(dt, velocity_gradient):
    """Calculate strain increment for a given time increment and velocity gradient.

    Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the
    “(engineering) shear strain” increment.

    The value is `|dt|` multiplied by the largest absolute eigenvalue of the
    symmetric part of `velocity_gradient`, i.e. of `(L + Lᵀ) / 2`.

    """
    return (
        np.abs(dt)
        * np.abs(
            # Symmetric part of the velocity gradient; `eigvalsh` is the eigenvalue
            # routine for symmetric (Hermitian) matrices.
            np.linalg.eigvalsh((velocity_gradient + velocity_gradient.transpose()) / 2)
        ).max()
    )
109
110
@nb.njit
def apply_gbs(
    orientations, fractions, gbs_threshold, orientations_prev, n_grains
) -> tuple[np.ndarray, np.ndarray]:
    """Apply grain boundary sliding for small grains.

    Grains whose volume fraction is below `gbs_threshold / n_grains` have their
    orientation replaced by the corresponding entry of `orientations_prev` and their
    fraction set to `gbs_threshold / n_grains`. Afterwards all fractions are divided
    by their sum.

    `orientations` and `fractions` are indexed with the grain along the first axis.
    Both arrays are modified in place and are also returned, as the tuple
    `(orientations, fractions)`.

    """
    # Grains below the size threshold are the ones affected by sliding.
    mask = fractions < (gbs_threshold / n_grains)
    # _log.debug(
    #     "grain boundary sliding activity (volume percentage): %s",
    #     len(np.nonzero(mask)) / len(fractions),
    # )
    # No rotation: carry over previous orientations.
    orientations[mask, :, :] = orientations_prev[mask, :, :]
    fractions[mask] = gbs_threshold / n_grains
    # Renormalise so that the fractions sum to one again.
    fractions /= fractions.sum()
    # _log.debug(
    #     "grain volume fractions: median=%e, min=%e, max=%e, sum=%e",
    #     np.median(fractions),
    #     np.min(fractions),
    #     np.max(fractions),
    #     np.sum(fractions),
    # )
    return orientations, fractions
133
134
@nb.njit
def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Extract deformation gradient, orientation matrices and grain fractions from y.

    The flat array `y` is read as three consecutive segments:

    - the first 9 entries, reshaped to the 3x3 deformation gradient,
    - the next `9 * n_grains` entries, reshaped to `(n_grains, 3, 3)` orientation
      matrices and clipped to the range [-1, 1],
    - the next `n_grains` entries, the grain fractions, clipped to be non-negative
      and then divided by their sum.

    Returns the tuple `(deformation_gradient, orientations, fractions)`.

    """
    deformation_gradient = y[:9].reshape((3, 3))
    orientations = y[9 : n_grains * 9 + 9].reshape((n_grains, 3, 3)).clip(-1, 1)
    fractions = y[n_grains * 9 + 9 : n_grains * 10 + 9].clip(0, None)
    # Renormalise after clipping so that the fractions sum to one.
    fractions /= fractions.sum()
    return deformation_gradient, orientations, fractions
143
144
def remove_nans(a):
    """Remove NaN values from array.

    The input is converted with `np.asarray` and indexed with a boolean mask, so the
    result is always a 1D array of the non-NaN values.

    """
    values = np.asarray(a)
    keep = np.logical_not(np.isnan(values))
    return values[keep]
149
150
def remove_dim(a, dim):
    """Remove all values corresponding to dimension `dim` from an array.

    The entry at index `dim` is deleted along every axis of the array in turn.
    Note that a `dim` of 0 refers to the “x” values.

    Examples:

    >>> a = [1, 2, 3]
    >>> remove_dim(a, 0)
    array([2, 3])
    >>> remove_dim(a, 1)
    array([1, 3])
    >>> remove_dim(a, 2)
    array([1, 2])

    >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    >>> remove_dim(a, 0)
    array([[5, 6],
           [8, 9]])
    >>> remove_dim(a, 1)
    array([[1, 3],
           [7, 9]])
    >>> remove_dim(a, 2)
    array([[1, 2],
           [4, 5]])

    """
    trimmed = np.asarray(a)
    # The number of axes does not change while deleting, only their lengths.
    for axis in range(trimmed.ndim):
        trimmed = np.delete(trimmed, [dim], axis=axis)
    return trimmed
182
183
def add_dim(a, dim, val=0):
    """Add entries of `val` corresponding to dimension `dim` to an array.

    An entry equal to `val` is inserted at index `dim` along every axis of the array
    in turn. Note that a `dim` of 0 refers to the “x” values.

    `val` is converted to the data type of the array by `numpy.insert`, so e.g.
    inserting a float into an integer array truncates it.

    Examples:

    >>> a = [1, 2]
    >>> add_dim(a, 0)
    array([0, 1, 2])
    >>> add_dim(a, 1)
    array([1, 0, 2])
    >>> add_dim(a, 2)
    array([1, 2, 0])

    >>> add_dim([1.0, 2.0], 2)
    array([1., 2., 0.])

    >>> add_dim([1, 2], 1, val=9)
    array([1, 9, 2])

    >>> a = [[1, 2], [3, 4]]
    >>> add_dim(a, 0)
    array([[0, 0, 0],
           [0, 1, 2],
           [0, 3, 4]])
    >>> add_dim(a, 1)
    array([[1, 0, 2],
           [0, 0, 0],
           [3, 0, 4]])
    >>> add_dim(a, 2)
    array([[1, 2, 0],
           [3, 4, 0],
           [0, 0, 0]])

    """
    _a = np.asarray(a)
    for axis in range(_a.ndim):
        # Insert the requested fill value (previously a hard-coded 0 was used here,
        # which silently ignored the `val` argument).
        _a = np.insert(_a, [dim], val, axis=axis)
    return _a
221
222
223def default_ncpus() -> int:
224    """Get a safe default number of CPUs available for multiprocessing.
225
226    On Linux platforms that support it, the method `os.sched_getaffinity()` is used.
227    On Mac OS, the command `sysctl -n hw.ncpu` is used.
228    On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried.
229    If any of these fail, a fallback of 1 is used and a warning is logged.
230
231    """
232    try:
233        match platform.system():
234            case "Linux":
235                return len(os.sched_getaffinity(0)) - 1  # May raise AttributeError.
236            case "Darwin":
237                # May raise CalledProcessError.
238                out = subprocess.run(
239                    ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True
240                )
241                return int(out.stdout.strip()) - 1
242            case "Windows":
243                return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1
244            case _:
245                return 1
246    except (AttributeError, subprocess.CalledProcessError, KeyError):
247        return 1
248
249
def diff_like(a):
    """Get forward difference of 2D array `a`, with repeated last elements.

    Differences are taken along the last axis. Before differencing, each row is
    extended by one linearly extrapolated value, so the last difference of each row
    repeats the one before it and the output has the same shape as the
    (at least 2D) input.

    Examples:

    >>> diff_like(np.array([1, 2, 3, 4, 5]))
    array([[1, 1, 1, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
    array([[1, 1, 1, 1, 1],
           [2, 3, 3, 1, 1]])

    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
    array([[ 1.,  1.,  1.,  1.,  1.],
           [ 2.,  3.,  3.,  1.,  1.],
           [-1.,  0.,  0., inf, nan]])

    """
    rows = np.atleast_2d(a)
    last = rows[:, -1]
    # Extrapolate one step past the end using the final difference of each row.
    extrapolated = last + (last - rows[:, -2])
    extra_column = np.reshape(extrapolated, (rows.shape[0], 1))
    return np.diff(rows, append=extra_column)
274
275
def angle_fse_simpleshear(strain):
    """Get angle of FSE long axis anticlockwise from the X axis in simple shear.

    The angle is returned in degrees, computed as `arctan(sqrt(strain² + 1) + strain)`.

    """
    slope = np.sqrt(strain**2 + 1) + strain
    angle_rad = np.arctan(slope)
    return np.rad2deg(angle_rad)
279
280
def lag_2d_corner_flow(θ):
    """Get predicted grain orientation lag for 2D corner flow.

    See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222).

    Values of `θ` below 1e-15 (including negative values) are masked in the
    returned masked array, which avoids division by `tan(θ) = 0`.

    """
    # Masked copy of the input: entries below the cutoff are excluded.
    theta = np.ma.masked_less(θ, 1e-15)
    # Sub-expression shared by numerator and denominator.
    common = theta**2 + np.cos(theta) ** 2
    return (theta * common) / (np.tan(theta) * (common - theta * np.sin(2 * theta)))
291
292
@nb.njit(fastmath=True)
def quat_product(q1, q2):
    """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.

    With vector parts `v1`, `v2` and scalar parts `w1`, `w2`, the product is

    - vector part: `w1 * v2 + w2 * v1 + v1 × v2`
    - scalar part: `w1 * w2 - v1 · v2`

    Returns a list of the four components `[x, y, z, w]`.

    """
    return [
        # The cross product must mix both operands: `v1 × v1` is always zero,
        # which would make the product wrongly commutative.
        *q1[-1] * q2[:3] + q2[-1] * q1[:3] + np.cross(q1[:3], q2[:3]),
        q1[-1] * q2[-1] - np.dot(q1[:3], q2[:3]),
    ]
300
301
def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
    """Redraw legend on matplotlib axis or figure.

    Transparency is removed from legend symbols.
    If `fig` is not None, any existing figure-level legends are removed first,
    and if `remove_all` is also True, the legends of all axes in the figure
    are removed as well.
    Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default,
    or `matplotlib.figure.Figure.legend` if `fig` is not None.

    If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes
    instead of taking up space in the original axes. This option requires `fig=None`.

    Returns the newly created legend object.

    .. warning::
        Note that if `fig` is not `None`, the legend may be cropped from the saved
        figure due to a Matplotlib bug. In this case, it is required to add the
        arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`,
        where `legend` is the object returned by this function. To prevent the legend
        from consuming axes/subplot space, it is further required to add the lines:
        `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)`
        and `fig.set_layout_engine("none")` before saving the figure.

    """
    handler_map = {
        PathCollection: HandlerPathCollection(
            update_func=_remove_legend_symbol_transparency
        ),
        Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency),
    }
    if fig is None:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if legendax is not None:
            # Collect the handles regardless of whether a legend was already drawn,
            # otherwise they would be undefined for axes without an existing legend.
            handles, labels = ax.get_legend_handles_labels()
            legendax.axis("off")
            return legendax.legend(handles, labels, handler_map=handler_map, **kwargs)
        return ax.legend(handler_map=handler_map, **kwargs)

    if legendax is not None:
        _log.warning("ignoring `legendax` argument which requires `fig=None`")
    # Iterate over a copy: removing a figure legend mutates `fig.legends`.
    for legend in list(fig.legends):
        if legend is not None:
            legend.remove()
    if remove_all:
        for _ax in fig.axes:
            legend = _ax.get_legend()
            if legend is not None:
                legend.remove()
    return fig.legend(handler_map=handler_map, **kwargs)
351
352
def add_subplot_labels(
    mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs
):
    """Add subplot labels to axes mosaic.

    Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels.
    If `labelmap` is None, the keys in `mosaic` will be used as the labels by default.

    If `internal` is `False` (default), the axes titles will be used,
    and any additional keyword arguments are passed to `ax.set_title`.
    Otherwise, internal labels will be drawn with `ax.text`,
    in which case `loc` must be a tuple of floats (a `ValueError` is raised if it is
    a string) and the additional keyword arguments are not used.

    Any axes in `mosaic` whose key is `legend` (in any letter case) are skipped.

    """
    for key, axes in mosaic.items():
        if key.lower() == "legend":
            continue
        label = key if labelmap is None else labelmap[key]
        if not internal:
            axes.set_title(label, loc=loc, fontsize=fontsize, **kwargs)
            continue
        # Shift the label slightly right and down from the requested position.
        offset = ScaledTranslation(10 / 72, -5 / 72, axes.figure.dpi_scale_trans)
        if isinstance(loc, str):
            raise ValueError(
                "'loc' argument must be a sequence of float when 'internal' is 'True'"
            )
        # Semi-transparent white background box behind the label text.
        label_box = {
            "facecolor": (1.0, 1.0, 1.0, 0.3),
            "edgecolor": "none",
            "pad": 3.0,
        }
        axes.text(
            *loc,
            label,
            transform=axes.transAxes + offset,
            fontsize=fontsize,
            bbox=label_box,
        )
391
392
def _remove_legend_symbol_transparency(handle, orig):
    """Remove transparency from symbols used in a Matplotlib legend.

    Used as the `update_func` of the legend handlers in `redraw_legend`:
    `handle` is the legend symbol being set up and `orig` is the artist it mirrors.

    """
    # https://stackoverflow.com/a/59629242/12519962
    # Copy all properties from the original artist first, then override the alpha,
    # so that the order of these two statements matters.
    handle.update_from(orig)
    handle.set_alpha(1)
def import_proc_pool() -> tuple:
21def import_proc_pool() -> tuple:
22    """Import either `ray.util.multiprocessing.Pool` or `multiprocessing.Pool`.
23
24    Import a process `Pool` object either from Ray of from Python's stdlib.
25    Both offer the same API, the Ray implementation will be preferred if available.
26    Using the `Pool` provided by Ray allows for distributed memory multiprocessing.
27
28    Returns a tuple containing the `Pool` object and a boolean flag which is `True` if
29    Ray is available.
30
31    """
32    try:
33        from ray.util.multiprocessing import Pool
34
35        has_ray = True
36    except ImportError:
37        from multiprocessing import Pool
38
39        has_ray = False
40    return Pool, has_ray

Import either ray.util.multiprocessing.Pool or multiprocessing.Pool.

Import a process Pool object either from Ray or from Python's stdlib. Both offer the same API, the Ray implementation will be preferred if available. Using the Pool provided by Ray allows for distributed memory multiprocessing.

Returns a tuple containing the Pool object and a boolean flag which is True if Ray is available.

def in_ci(platform: str) -> bool:
43def in_ci(platform: str) -> bool:
44    """Check if we are in a GitHub runner with the given operating system."""
45    # https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
46    return sys.platform == platform and os.getenv("CI") is not None

Check if we are in a GitHub runner with the given operating system.

class SerializedCallable:
49class SerializedCallable:
50    """A serialized version of the callable f.
51
52    Serialization is performed using the dill library. The object is safe to pass into
53    `multiprocessing.Pool.map` and its alternatives.
54
55    .. note:: To serialize a lexical closure (i.e. a function defined inside a
56        function), use the `serializable` decorator.
57
58    """
59
60    def __init__(self, f):
61        self._f = dill.dumps(f, protocol=5, byref=True)
62
63    def __call__(self, *args, **kwargs):
64        return dill.loads(self._f)(*args, **kwargs)

A serialized version of the callable f.

Serialization is performed using the dill library. The object is safe to pass into multiprocessing.Pool.map and its alternatives.

To serialize a lexical closure (i.e. a function defined inside a

function), use the serializable decorator.

SerializedCallable(f)
60    def __init__(self, f):
61        self._f = dill.dumps(f, protocol=5, byref=True)
def serializable(f):
67def serializable(f):
68    """Make decorated function serializable.
69
70    .. warning:: The decorated function cannot be a method, and it will loose its
71        docstring. It is not possible to use `functools.wraps` to mitigate this.
72
73    """
74    return SerializedCallable(f)

Make decorated function serializable.

The decorated function cannot be a method, and it will lose its

docstring. It is not possible to use functools.wraps to mitigate this.

def defined_if(cond):
77def defined_if(cond):
78    """Only define decorated function if `cond` is `True`."""
79
80    def _defined_if(f):
81        def not_f(*args, **kwargs):
82            # Throw the same as we would get from `type(undefined_symbol)`.
83            raise NameError(f"name '{f.__name__}' is not defined")
84
85        @wraps(f)
86        def wrapper(*args, **kwargs):
87            if cond:
88                return f(*args, **kwargs)
89            return not_f(*args, **kwargs)
90
91        return wrapper
92
93    return _defined_if

Only define decorated function if cond is True.

@nb.njit(fastmath=True)
def strain_increment(dt, velocity_gradient):
 96@nb.njit(fastmath=True)
 97def strain_increment(dt, velocity_gradient):
 98    """Calculate strain increment for a given time increment and velocity gradient.
 99
100    Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the
101    “(engineering) shear strain” increment.
102
103    """
104    return (
105        np.abs(dt)
106        * np.abs(
107            np.linalg.eigvalsh((velocity_gradient + velocity_gradient.transpose()) / 2)
108        ).max()
109    )

Calculate strain increment for a given time increment and velocity gradient.

Returns “tensorial” strain increment ε, which is equal to γ/2 where γ is the “(engineering) shear strain” increment.

@nb.njit
def apply_gbs( orientations, fractions, gbs_threshold, orientations_prev, n_grains) -> tuple[numpy.ndarray, numpy.ndarray]:
112@nb.njit
113def apply_gbs(
114    orientations, fractions, gbs_threshold, orientations_prev, n_grains
115) -> tuple[np.ndarray, np.ndarray]:
116    """Apply grain boundary sliding for small grains."""
117    mask = fractions < (gbs_threshold / n_grains)
118    # _log.debug(
119    #     "grain boundary sliding activity (volume percentage): %s",
120    #     len(np.nonzero(mask)) / len(fractions),
121    # )
122    # No rotation: carry over previous orientations.
123    orientations[mask, :, :] = orientations_prev[mask, :, :]
124    fractions[mask] = gbs_threshold / n_grains
125    fractions /= fractions.sum()
126    # _log.debug(
127    #     "grain volume fractions: median=%e, min=%e, max=%e, sum=%e",
128    #     np.median(fractions),
129    #     np.min(fractions),
130    #     np.max(fractions),
131    #     np.sum(fractions),
132    # )
133    return orientations, fractions

Apply grain boundary sliding for small grains.

@nb.njit
def extract_vars(y, n_grains) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
136@nb.njit
137def extract_vars(y, n_grains) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
138    """Extract deformation gradient, orientation matrices and grain sizes from y."""
139    deformation_gradient = y[:9].reshape((3, 3))
140    orientations = y[9 : n_grains * 9 + 9].reshape((n_grains, 3, 3)).clip(-1, 1)
141    fractions = y[n_grains * 9 + 9 : n_grains * 10 + 9].clip(0, None)
142    fractions /= fractions.sum()
143    return deformation_gradient, orientations, fractions

Extract deformation gradient, orientation matrices and grain sizes from y.

def remove_nans(a):
146def remove_nans(a):
147    """Remove NaN values from array."""
148    a = np.asarray(a)
149    return a[~np.isnan(a)]

Remove NaN values from array.

def remove_dim(a, dim):
152def remove_dim(a, dim):
153    """Remove all values corresponding to dimension `dim` from an array.
154
155    Note that a `dim` of 0 refers to the “x” values.
156
157    Examples:
158
159    >>> a = [1, 2, 3]
160    >>> remove_dim(a, 0)
161    array([2, 3])
162    >>> remove_dim(a, 1)
163    array([1, 3])
164    >>> remove_dim(a, 2)
165    array([1, 2])
166
167    >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
168    >>> remove_dim(a, 0)
169    array([[5, 6],
170           [8, 9]])
171    >>> remove_dim(a, 1)
172    array([[1, 3],
173           [7, 9]])
174    >>> remove_dim(a, 2)
175    array([[1, 2],
176           [4, 5]])
177
178    """
179    _a = np.asarray(a)
180    for i, _ in enumerate(_a.shape):
181        _a = np.delete(_a, [dim], axis=i)
182    return _a

Remove all values corresponding to dimension dim from an array.

Note that a dim of 0 refers to the “x” values.

Examples:

>>> a = [1, 2, 3]
>>> remove_dim(a, 0)
array([2, 3])
>>> remove_dim(a, 1)
array([1, 3])
>>> remove_dim(a, 2)
array([1, 2])
>>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
>>> remove_dim(a, 0)
array([[5, 6],
       [8, 9]])
>>> remove_dim(a, 1)
array([[1, 3],
       [7, 9]])
>>> remove_dim(a, 2)
array([[1, 2],
       [4, 5]])
def add_dim(a, dim, val=0):
def add_dim(a, dim, val=0):
    """Add entries of `val` corresponding to dimension `dim` to an array.

    An entry equal to `val` is inserted at index `dim` along every axis of the array
    in turn. Note that a `dim` of 0 refers to the “x” values.

    `val` is converted to the data type of the array by `numpy.insert`, so e.g.
    inserting a float into an integer array truncates it.

    Examples:

    >>> a = [1, 2]
    >>> add_dim(a, 0)
    array([0, 1, 2])
    >>> add_dim(a, 1)
    array([1, 0, 2])
    >>> add_dim(a, 2)
    array([1, 2, 0])

    >>> add_dim([1.0, 2.0], 2)
    array([1., 2., 0.])

    >>> add_dim([1, 2], 1, val=9)
    array([1, 9, 2])

    >>> a = [[1, 2], [3, 4]]
    >>> add_dim(a, 0)
    array([[0, 0, 0],
           [0, 1, 2],
           [0, 3, 4]])
    >>> add_dim(a, 1)
    array([[1, 0, 2],
           [0, 0, 0],
           [3, 0, 4]])
    >>> add_dim(a, 2)
    array([[1, 2, 0],
           [3, 4, 0],
           [0, 0, 0]])

    """
    _a = np.asarray(a)
    for axis in range(_a.ndim):
        # Insert the requested fill value (previously a hard-coded 0 was used here,
        # which silently ignored the `val` argument).
        _a = np.insert(_a, [dim], val, axis=axis)
    return _a

Add entries of val corresponding to dimension dim to an array.

Note that a dim of 0 refers to the “x” values.

Examples:

>>> a = [1, 2]
>>> add_dim(a, 0)
array([0, 1, 2])
>>> add_dim(a, 1)
array([1, 0, 2])
>>> add_dim(a, 2)
array([1, 2, 0])
>>> add_dim([1.0, 2.0], 2)
array([1., 2., 0.])
>>> a = [[1, 2], [3, 4]]
>>> add_dim(a, 0)
array([[0, 0, 0],
       [0, 1, 2],
       [0, 3, 4]])
>>> add_dim(a, 1)
array([[1, 0, 2],
       [0, 0, 0],
       [3, 0, 4]])
>>> add_dim(a, 2)
array([[1, 2, 0],
       [3, 4, 0],
       [0, 0, 0]])
def default_ncpus() -> int:
224def default_ncpus() -> int:
225    """Get a safe default number of CPUs available for multiprocessing.
226
227    On Linux platforms that support it, the method `os.sched_getaffinity()` is used.
228    On Mac OS, the command `sysctl -n hw.ncpu` is used.
229    On Windows, the environment variable `NUMBER_OF_PROCESSORS` is queried.
230    If any of these fail, a fallback of 1 is used and a warning is logged.
231
232    """
233    try:
234        match platform.system():
235            case "Linux":
236                return len(os.sched_getaffinity(0)) - 1  # May raise AttributeError.
237            case "Darwin":
238                # May raise CalledProcessError.
239                out = subprocess.run(
240                    ["sysctl", "-n", "hw.ncpu"], capture_output=True, check=True
241                )
242                return int(out.stdout.strip()) - 1
243            case "Windows":
244                return int(os.environ["NUMBER_OF_PROCESSORS"]) - 1
245            case _:
246                return 1
247    except (AttributeError, subprocess.CalledProcessError, KeyError):
248        return 1

Get a safe default number of CPUs available for multiprocessing.

On Linux platforms that support it, the method os.sched_getaffinity() is used. On Mac OS, the command sysctl -n hw.ncpu is used. On Windows, the environment variable NUMBER_OF_PROCESSORS is queried. If any of these fail, a fallback of 1 is used and a warning is logged.

def diff_like(a):
251def diff_like(a):
252    """Get forward difference of 2D array `a`, with repeated last elements.
253
254    The repeated last elements ensure that output and input arrays have equal shape.
255
256    Examples:
257
258    >>> diff_like(np.array([1, 2, 3, 4, 5]))
259    array([[1, 1, 1, 1, 1]])
260
261    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
262    array([[1, 1, 1, 1, 1],
263           [2, 3, 3, 1, 1]])
264
265    >>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
266    array([[ 1.,  1.,  1.,  1.,  1.],
267           [ 2.,  3.,  3.,  1.,  1.],
268           [-1.,  0.,  0., inf, nan]])
269
270    """
271    a2 = np.atleast_2d(a)
272    return np.diff(
273        a2, append=np.reshape(a2[:, -1] + (a2[:, -1] - a2[:, -2]), (a2.shape[0], 1))
274    )

Get forward difference of 2D array a, with repeated last elements.

The repeated last elements ensure that output and input arrays have equal shape.

Examples:

>>> diff_like(np.array([1, 2, 3, 4, 5]))
array([[1, 1, 1, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10]]))
array([[1, 1, 1, 1, 1],
       [2, 3, 3, 1, 1]])
>>> diff_like(np.array([[1, 2, 3, 4, 5], [1, 3, 6, 9, 10], [1, 0, 0, 0, np.inf]]))
array([[ 1.,  1.,  1.,  1.,  1.],
       [ 2.,  3.,  3.,  1.,  1.],
       [-1.,  0.,  0., inf, nan]])
def angle_fse_simpleshear(strain):
277def angle_fse_simpleshear(strain):
278    """Get angle of FSE long axis anticlockwise from the X axis in simple shear."""
279    return np.rad2deg(np.arctan(np.sqrt(strain**2 + 1) + strain))

Get angle of FSE long axis anticlockwise from the X axis in simple shear.

def lag_2d_corner_flow(θ):
def lag_2d_corner_flow(θ):
    """Get predicted grain orientation lag for 2D corner flow.

    See eq. 11 in [Kaminski & Ribe (2002)](https://doi.org/10.1029/2001GC000222).

    Values of `θ` below 1e-15 (including negative values) are masked in the
    returned masked array, which avoids division by `tan(θ) = 0`.

    """
    # Masked copy of the input: entries below the cutoff are excluded.
    theta = np.ma.masked_less(θ, 1e-15)
    # Sub-expression shared by numerator and denominator.
    common = theta**2 + np.cos(theta) ** 2
    return (theta * common) / (np.tan(theta) * (common - theta * np.sin(2 * theta)))

Get predicted grain orientation lag for 2D corner flow.

See eq. 11 in Kaminski & Ribe (2002).

@nb.njit(fastmath=True)
def quat_product(q1, q2):
@nb.njit(fastmath=True)
def quat_product(q1, q2):
    """Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.

    With vector parts `v1`, `v2` and scalar parts `w1`, `w2`, the product is

    - vector part: `w1 * v2 + w2 * v1 + v1 × v2`
    - scalar part: `w1 * w2 - v1 · v2`

    Returns a list of the four components `[x, y, z, w]`.

    """
    return [
        # The cross product must mix both operands: `v1 × v1` is always zero,
        # which would make the product wrongly commutative.
        *q1[-1] * q2[:3] + q2[-1] * q1[:3] + np.cross(q1[:3], q2[:3]),
        q1[-1] * q2[-1] - np.dot(q1[:3], q2[:3]),
    ]

Quaternion product, q1, q2 and output are in scalar-last (x,y,z,w) format.

def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
def redraw_legend(ax, fig=None, legendax=None, remove_all=True, **kwargs):
    """Redraw legend on matplotlib axis or figure.

    Transparency is removed from legend symbols.
    If `fig` is not None, any existing figure-level legends are removed first,
    and if `remove_all` is also True, the legends of all axes in the figure
    are removed as well.
    Optional keyword arguments are passed to `matplotlib.axes.Axes.legend` by default,
    or `matplotlib.figure.Figure.legend` if `fig` is not None.

    If `legendax` is not None, the axis legend will be redrawn using the `legendax` axes
    instead of taking up space in the original axes. This option requires `fig=None`.

    Returns the newly created legend object.

    .. warning::
        Note that if `fig` is not `None`, the legend may be cropped from the saved
        figure due to a Matplotlib bug. In this case, it is required to add the
        arguments `bbox_extra_artists=(legend,)` and `bbox_inches="tight"` to `savefig`,
        where `legend` is the object returned by this function. To prevent the legend
        from consuming axes/subplot space, it is further required to add the lines:
        `legend.set_in_layout(False)`, `fig.canvas.draw()`, `legend.set_in_layout(True)`
        and `fig.set_layout_engine("none")` before saving the figure.

    """
    handler_map = {
        PathCollection: HandlerPathCollection(
            update_func=_remove_legend_symbol_transparency
        ),
        Line2D: HandlerLine2D(update_func=_remove_legend_symbol_transparency),
    }
    if fig is None:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if legendax is not None:
            # Collect the handles regardless of whether a legend was already drawn,
            # otherwise they would be undefined for axes without an existing legend.
            handles, labels = ax.get_legend_handles_labels()
            legendax.axis("off")
            return legendax.legend(handles, labels, handler_map=handler_map, **kwargs)
        return ax.legend(handler_map=handler_map, **kwargs)

    if legendax is not None:
        _log.warning("ignoring `legendax` argument which requires `fig=None`")
    # Iterate over a copy: removing a figure legend mutates `fig.legends`.
    for legend in list(fig.legends):
        if legend is not None:
            legend.remove()
    if remove_all:
        for _ax in fig.axes:
            legend = _ax.get_legend()
            if legend is not None:
                legend.remove()
    return fig.legend(handler_map=handler_map, **kwargs)

Redraw legend on matplotlib axis or figure.

Transparency is removed from legend symbols. If fig is not None and remove_all is True, all legends are first removed from the parent figure. Optional keyword arguments are passed to matplotlib.axes.Axes.legend by default, or matplotlib.figure.Figure.legend if fig is not None.

If legendax is not None, the axis legend will be redrawn using the legendax axes instead of taking up space in the original axes. This option requires fig=None.

Note that if fig is not None, the legend may be cropped from the saved figure due to a Matplotlib bug. In this case, it is required to add the arguments bbox_extra_artists=(legend,) and bbox_inches="tight" to savefig, where legend is the object returned by this function. To prevent the legend from consuming axes/subplot space, it is further required to add the lines: legend.set_in_layout(False), fig.canvas.draw(), legend.set_in_layout(True) and fig.set_layout_engine("none") before saving the figure.

def add_subplot_labels( mosaic, labelmap=None, loc='left', fontsize='medium', internal=False, **kwargs):
354def add_subplot_labels(
355    mosaic, labelmap=None, loc="left", fontsize="medium", internal=False, **kwargs
356):
357    """Add subplot labels to axes mosaic.
358
359    Use `labelmap` to specify a dictionary that maps keys in `mosaic` to subplot labels.
360    If `labelmap` is None, the keys in `axs` will be used as the labels by default.
361
362    If `internal` is `False` (default), the axes titles will be used.
363    Otherwise, internal labels will be drawn with `ax.text`,
364    in which case `loc` must be a tuple of floats.
365
366    Any axes in `axs` corresponding to the special key `legend` are skipped.
367
368    """
369    for txt, ax in mosaic.items():
370        if txt.lower() == "legend":
371            continue
372        _txt = labelmap[txt] if labelmap is not None else txt
373        if internal:
374            trans = ScaledTranslation(10 / 72, -5 / 72, ax.figure.dpi_scale_trans)
375            if isinstance(loc, str):
376                raise ValueError(
377                    "'loc' argument must be a sequence of float when 'internal' is 'True'"
378                )
379            ax.text(
380                *loc,
381                _txt,
382                transform=ax.transAxes + trans,
383                fontsize=fontsize,
384                bbox={
385                    "facecolor": (1.0, 1.0, 1.0, 0.3),
386                    "edgecolor": "none",
387                    "pad": 3.0,
388                },
389            )
390        else:
391            ax.set_title(_txt, loc=loc, fontsize=fontsize, **kwargs)

Add subplot labels to axes mosaic.

Use labelmap to specify a dictionary that maps keys in mosaic to subplot labels. If labelmap is None, the keys in mosaic will be used as the labels by default.

If internal is False (default), the axes titles will be used. Otherwise, internal labels will be drawn with ax.text, in which case loc must be a tuple of floats.

Any axes in mosaic corresponding to the special key legend are skipped.