pydrex.visualisation

PyDRex: Visualisation functions for test outputs and examples.

  1"""> PyDRex: Visualisation functions for test outputs and examples."""
  2
  3import numpy as np
  4from cmcrameri import cm as cmc
  5from matplotlib import projections as mproj
  6from matplotlib import pyplot as plt
  7
  8from pydrex import axes as _axes
  9from pydrex import core as _core
 10from pydrex import geometry as _geo
 11from pydrex import io as _io
 12from pydrex import logger as _log
 13from pydrex import utils as _utils
 14
 15# Use a non-interactive vector-graphics backend by default.
 16plt.rcParams["backend"] = "PDF"
 17# Get default figure size for easy referencing and scaling.
 18DEFAULT_FIG_WIDTH, DEFAULT_FIG_HEIGHT = plt.rcParams["figure.figsize"]
 19plt.rcParams["axes.grid"] = True
 20# Always draw grid behind everything else.
 21plt.rcParams["axes.axisbelow"] = True
 22# Always use constrained layout by default (modern version of tight layout).
 23plt.rcParams["figure.constrained_layout.use"] = True
 24# Use 300 DPI by default, NASA can keep their blurry images.
 25plt.rcParams["figure.dpi"] = 300
 26# Make sure we have the required matplotlib "projections" (really just Axes subclasses).
 27if "pydrex.polefigure" not in mproj.get_projection_names():
 28    _log.warning(
 29        "failed to find pydrex.polefigure projection; it should be registered in %s",
 30        _axes,
 31    )
 32
 33
 34def default_tick_formatter(x, pos):
 35    return f"{x/1e3:.1f}"
 36
 37
 38def polefigures(
 39    orientations,
 40    ref_axes,
 41    i_range,
 42    density=False,
 43    savefile="polefigures.png",
 44    strains=None,
 45    **kwargs,
 46):
 47    """Plot pole figures of a series of (Nx3x3) orientation matrix stacks.
 48
 49    Produces [100], [010] and [001] pole figures for (resampled) orientations.
 50    For the argument specification, check the output of `pydrex-polefigures --help`
 51    on the command line.
 52
 53    """
 54    if len(orientations) != len(i_range):
 55        raise ValueError("mismatched length of 'orientations' and 'i_range'")
 56    if strains is not None and len(strains) != len(i_range):
 57        raise ValueError("mismatched length of 'strains'")
 58    n_orientations = len(orientations)
 59    fig = plt.figure(figsize=(n_orientations, 4), dpi=600)
 60
 61    if len(i_range) == 1:
 62        grid = fig.add_gridspec(3, n_orientations, hspace=0, wspace=0.2)
 63        first_row = 0
 64    else:
 65        grid = fig.add_gridspec(
 66            4, n_orientations, height_ratios=((1, 3, 3, 3)), hspace=0, wspace=0.2
 67        )
 68        fig_strain = fig.add_subfigure(grid[0, :])
 69        first_row = 1
 70        ax_strain = fig_strain.add_subplot(111)
 71
 72        if strains is None:
 73            fig_strain.suptitle(
 74                f"N ⋅ (max strain) / {i_range.stop}", x=0.5, y=0.85, fontsize="small"
 75            )
 76            ax_strain.set_xlim(
 77                (i_range.start - i_range.step / 2, i_range.stop - i_range.step / 2)
 78            )
 79            ax_strain.set_xticks(list(i_range))
 80        else:
 81            fig_strain.suptitle("strain (%)", x=0.5, y=0.85, fontsize="small")
 82            ax_strain.set_xticks(strains)
 83            ax_strain.set_xlim(
 84                (
 85                    strains[0] - strains[1] / 2,
 86                    strains[-1] + strains[1] / 2,
 87                )
 88            )
 89
 90        ax_strain.set_frame_on(False)
 91        ax_strain.grid(False)
 92        ax_strain.yaxis.set_visible(False)
 93        ax_strain.xaxis.set_tick_params(labelsize="x-small", length=0)
 94
 95    fig100 = fig.add_subfigure(
 96        grid[first_row, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
 97    )
 98    fig100.suptitle("[100]", fontsize="small")
 99    fig010 = fig.add_subfigure(
100        grid[first_row + 1, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
101    )
102    fig010.suptitle("[010]", fontsize="small")
103    fig001 = fig.add_subfigure(
104        grid[first_row + 2, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
105    )
106    fig001.suptitle("[001]", fontsize="small")
107    for n, orientations in enumerate(orientations):
108        ax100 = fig100.add_subplot(
109            1, n_orientations, n + 1, projection="pydrex.polefigure"
110        )
111        pf100 = ax100.polefigure(
112            orientations,
113            hkl=[1, 0, 0],
114            ref_axes=ref_axes,
115            density=density,
116            density_kwargs=kwargs,
117        )
118        ax010 = fig010.add_subplot(
119            1, n_orientations, n + 1, projection="pydrex.polefigure"
120        )
121        pf010 = ax010.polefigure(
122            orientations,
123            hkl=[0, 1, 0],
124            ref_axes=ref_axes,
125            density=density,
126            density_kwargs=kwargs,
127        )
128        ax001 = fig001.add_subplot(
129            1, n_orientations, n + 1, projection="pydrex.polefigure"
130        )
131        pf001 = ax001.polefigure(
132            orientations,
133            hkl=[0, 0, 1],
134            ref_axes=ref_axes,
135            density=density,
136            density_kwargs=kwargs,
137        )
138        if density:
139            for ax, pf in zip(
140                (ax100, ax010, ax001), (pf100, pf010, pf001), strict=True
141            ):
142                cbar = fig.colorbar(
143                    pf,
144                    ax=ax,
145                    fraction=0.05,
146                    location="bottom",
147                    orientation="horizontal",
148                )
149                cbar.ax.xaxis.set_tick_params(labelsize="xx-small")
150
151    fig.savefig(_io.resolve_path(savefile))
152
153
def steady_box2d(
    ax: plt.Axes | None,
    velocity: tuple,
    geometry: tuple,
    ref_axes: str,
    cpo: tuple | None,
    colors,
    aspect="equal",
    cmap=cmc.batlow,
    marker=".",
    tick_formatter=default_tick_formatter,
    label_suffix="(km)",
    **kwargs,
) -> tuple:
    """Plot pathlines and steady-state velocity arrows for a 2D box domain.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `velocity` — tuple of a velocity callable¹ and the 2D resolution of the velocity
      arrow grid, e.g. [20, 20] for 20x20 arrows over the rectangular domain; if the
      resolution is None, no velocity arrows are drawn
    - `geometry` — tuple of the array of 3D pathline positions and two 2D coordinates
      (the lower-left and upper-right domain corners)
    - `ref_axes` — two letters from {"x", "y", "z"} labelling the horizontal and
      vertical axes; they also select the projection of 3D velocities/positions
    - `cpo` — tuple of one array of CPO strengths and one of 3D CPO vectors; or `None`
      to only plot pathline positions with `marker`
    - `colors` — monotonic, increasing values along the pathline (e.g. time or strain)
    - `aspect` — optional, see `matplotlib.axes.Axes.set_aspect`
    - `cmap` — optional custom color map for `colors`
    - `marker` — optional pathline position marker used when `cpo` is `None`
    - `tick_formatter` — optional custom tick formatter callable
    - `label_suffix` — optional suffix added to the axes labels

    ¹with signature `f(t, x)` where `t` is not used and `x` is a 3D position vector

    Additional keyword arguments are passed to the `matplotlib.axes.Axes.quiver` call
    that draws the velocity vectors (`pivot` defaults to "mid", `alpha` to 0.25).
    `width` and `zorder` are additionally read for the CPO vectors, which are always
    drawn one z-order level above.

    Returns the figure handle, the axes handle, the quiver collection (velocities, or
    None) and the scatter/quiver collection (pathline).

    """
    fig, ax = figure_unless(ax)
    ax.set_xlabel(f"{ref_axes[0]} {label_suffix}")
    ax.set_ylabel(f"{ref_axes[1]} {label_suffix}")

    get_velocity, resolution = velocity
    positions, (x_min, y_min), (x_max, y_max) = geometry

    ax.set_xlim((x_min, x_max))
    ax.set_ylim((y_min, y_max))
    ax.set_aspect(aspect)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(tick_formatter)

    column_of = {"x": 0, "y": 1, "z": 2}
    lowered = ref_axes.lower()
    i_h = column_of[lowered[0]]
    i_v = column_of[lowered[1]]

    velocities = None
    if resolution is not None:
        n_x, n_y = resolution
        X_grid, Y_grid = np.meshgrid(
            np.linspace(x_min, x_max, n_x), np.linspace(y_min, y_max, n_y)
        )
        U = np.zeros_like(X_grid)
        V = np.zeros_like(Y_grid)
        # Sample the 3D velocity field at every grid node (the time argument is
        # unused), keeping only the two in-plane components.
        for idx in np.ndindex(X_grid.shape):
            point = np.zeros(3)
            point[i_h] = X_grid[idx]
            point[i_v] = Y_grid[idx]
            v3d = get_velocity(np.nan, point)
            U[idx] = v3d[i_h]
            V[idx] = v3d[i_v]

        pivot = kwargs.pop("pivot", "mid")
        alpha = kwargs.pop("alpha", 0.25)
        velocities = ax.quiver(
            X_grid, Y_grid, U, V, pivot=pivot, alpha=alpha, **kwargs
        )

    # The coordinate not spanned by `ref_axes` is dropped from 3D quantities.
    unused_dims = {0, 1, 2}.difference(_geo.to_indices2d(*ref_axes))
    dummy_dim = unused_dims.pop()
    xi_2D = np.asarray([_utils.remove_dim(p, dummy_dim) for p in positions])

    if cpo is None:
        pathline = ax.scatter(
            xi_2D[:, 0], xi_2D[:, 1], marker=marker, c=colors, cmap=cmap
        )
        return fig, ax, velocities, pathline

    cpo_strengths, cpo_vectors = cpo
    cpo_2D = np.asarray(
        [
            strength * _utils.remove_dim(vector, dummy_dim)
            for strength, vector in zip(cpo_strengths, cpo_vectors, strict=True)
        ]
    )
    # Headless, centred arrows: CPO directions are axial, not vectorial.
    pathline = ax.quiver(
        xi_2D[:, 0],
        xi_2D[:, 1],
        cpo_2D[:, 0],
        cpo_2D[:, 1],
        colors,
        cmap=cmap,
        pivot="mid",
        width=kwargs.pop("width", 3e-3),
        headaxislength=0,
        headlength=0,
        zorder=kwargs.pop("zorder", 10) + 1,  # Always above velocity vectors.
    )
    return fig, ax, velocities, pathline
270
271
def alignment(
    ax: plt.Axes | None,
    strains: np.ndarray,
    angles: np.ndarray,
    markers: list[str] | tuple[str],
    labels: list[str] | tuple[str],
    err: np.ndarray | None = None,
    θ_max: int = 90,
    θ_fse: np.ndarray | None = None,
    colors: np.ndarray | None = None,
    cmaps=None,
    **kwargs,
) -> tuple:
    """Plot `angles` (in degrees) versus `strains` on the given axis.

    Alignment angles could be either bingham averages or the a-axis in the hexagonal
    symmetry projection, measured from e.g. the shear direction. In the first case,
    they should be calculated from resampled grain orientations. Expects as many
    `markers` and `labels` as there are data series in `angles`.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `strains` — X-values, accumulated strain (tensorial) during CPO evolution, may be
      a 2D array of multiple strain series
    - `angles` — Y-values, may be a 2D array of multiple angle series
    - `markers` — MatPlotLib markers to use for the data series
    - `labels` — labels to use for the data series
    - `err` (optional) — standard errors for the `angles`, shapes must match
    - `θ_max` — maximum angle (°) to show on the plot, should be less than 90
    - `θ_fse` (optional) — an array of angles from the long axis of the finite strain
      ellipsoid to the reference direction (e.g. shear direction); plotted against
      the last strain series
    - `colors` (optional) — color coordinates for series of angles
    - `cmaps` (optional) — color maps for `colors`, one per series

    If `colors` and `cmaps` are used, then angle values are colored individually within
    each angle series.

    Additional keyword arguments are passed to `matplotlib.axes.Axes.scatter` if
    `colors` is not `None`, or to `matplotlib.axes.Axes.plot` otherwise. `alpha`
    (default 0.6) and, for scatter plots, `edgecolor` (default: the axes edge color)
    apply to every data series.

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    _strains = np.atleast_2d(strains)
    _angles = np.atleast_2d(angles)
    _angles_err = None if err is None else np.atleast_2d(err)
    if _strains.shape != _angles.shape:
        # Assume strains are all the same for each series in `angles`, try np.tile().
        _strains = np.tile(_strains, (len(_angles), 1))

    # Resolve style options once: popping them inside the loop would apply a
    # caller-supplied value to the first series only, and passing a fixed `alpha`
    # alongside `**kwargs` would raise a TypeError if the caller also set it.
    alpha = kwargs.pop("alpha", 0.6)
    edgecolor = (
        kwargs.pop("edgecolor", plt.rcParams["axes.edgecolor"])
        if colors is not None
        else None
    )

    fig, ax = figure_unless(ax)
    ax.set_ylabel(r"$\overline{θ}$ ∈ [0, 90]°")
    ax.set_ylim((0, θ_max))
    ax.set_xlabel("Strain (ε)")
    ax.set_xlim((np.min(strains), np.max(strains)))
    _colors = []
    for i, (series_strains, θ_cpo, marker, label) in enumerate(
        zip(_strains, _angles, markers, labels, strict=True)
    ):
        if colors is not None:
            ax.scatter(
                series_strains,
                θ_cpo,
                marker=marker,
                label=label,
                c=colors[i],
                cmap=None if cmaps is None else cmaps[i],
                alpha=alpha,
                edgecolor=edgecolor,
                **kwargs,
            )
            _colors.append(colors[i])
        else:
            lines = ax.plot(
                series_strains, θ_cpo, marker, alpha=alpha, label=label, **kwargs
            )
            _colors.append(lines[0].get_color())
        if _angles_err is not None:
            ax.fill_between(
                series_strains,
                θ_cpo - _angles_err[i],
                θ_cpo + _angles_err[i],
                alpha=0.22,
                color=_colors[i],
            )

    if θ_fse is not None:
        ax.plot(_strains[-1], θ_fse, linestyle=(0, (5, 5)), alpha=0.6, label="FSE")
    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, _colors
363
364
def strengths(
    ax: plt.Axes | None,
    strains: np.ndarray,
    strengths: np.ndarray,
    ylabel: str,
    markers: list[str] | tuple[str],
    labels: list[str] | tuple[str],
    err: np.ndarray | None = None,
    cpo_threshold: float | None = None,
    colors: np.ndarray | None = None,
    cmaps=None,
    **kwargs,
):
    """Plot CPO `strengths` (e.g. M-indices) versus `strains` on the given axis.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `strains` — X-values, accumulated strain (tensorial) during CPO evolution, may be
      a 2D array of multiple strain series
    - `strengths` — Y-values, may be a 2D array of multiple strength series
    - `ylabel` — label for the Y axis, depending on chosen texture strength measure
    - `markers` — MatPlotLib markers to use for the data series
    - `labels` — labels to use for the data series
    - `err` (optional) — standard errors for the `strengths`, shapes must match
    - `colors` (optional) — color coordinates for series of strengths
    - `cpo_threshold` (optional) — plot a dashed line at this threshold
    - `cmaps` (optional) — color maps for `colors`, one per series

    If `colors` and `cmaps` are used, then strength values are colored individually
    within each strength series.

    Additional keyword arguments are passed to `matplotlib.axes.Axes.scatter` if
    `colors` is not `None`, or to `matplotlib.axes.Axes.plot` otherwise. `alpha`
    (default 0.6 for scatter, 0.33 for line plots) and, for scatter plots, `edgecolor`
    (default: the axes edge color) apply to every data series.

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    _strains = np.atleast_2d(strains)
    _strengths = np.atleast_2d(strengths)
    _strengths_err = None if err is None else np.atleast_2d(err)
    if _strains.shape != _strengths.shape:
        # Assume strains are all the same for each series in `strengths`, try np.tile().
        _strains = np.tile(_strains, (len(_strengths), 1))

    # Resolve style options once: popping them inside the loop would apply a
    # caller-supplied value to the first series only, and passing a fixed `alpha`
    # alongside `**kwargs` would raise a TypeError if the caller also set it.
    alpha = kwargs.pop("alpha", 0.6 if colors is not None else 0.33)
    edgecolor = (
        kwargs.pop("edgecolor", plt.rcParams["axes.edgecolor"])
        if colors is not None
        else None
    )

    fig, ax = figure_unless(ax)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Strain (ε)")
    ax.set_xlim((np.min(strains), np.max(strains)))

    if cpo_threshold is not None:
        ax.axhline(cpo_threshold, color=plt.rcParams["axes.edgecolor"], linestyle="--")

    _colors = []
    for i, (series_strains, strength, marker, label) in enumerate(
        zip(_strains, _strengths, markers, labels, strict=True)
    ):
        if colors is not None:
            ax.scatter(
                series_strains,
                strength,
                marker=marker,
                label=label,
                c=colors[i],
                cmap=None if cmaps is None else cmaps[i],
                alpha=alpha,
                edgecolor=edgecolor,
                **kwargs,
            )
            _colors.append(colors[i])
        else:
            lines = ax.plot(
                series_strains, strength, marker, alpha=alpha, label=label, **kwargs
            )
            _colors.append(lines[0].get_color())
        if _strengths_err is not None:
            ax.fill_between(
                series_strains,
                strength - _strengths_err[i],
                strength + _strengths_err[i],
                alpha=0.22,
                color=_colors[i],
            )

    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, _colors
453
454
def grainsizes(ax, strains, fractions) -> tuple:
    """Plot grain volume `fractions` versus `strains` on the given axis.

    Each entry of `fractions` is drawn as an opaque black violin at the matching
    position in `strains`, showing the distribution of log₁₀(f × N). N is the
    number of grains, taken from the length of the first entry of `fractions`.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    Returns a tuple of the figure handle, the axes handle and the dictionary of
    artists returned by `matplotlib.axes.Axes.violinplot`.

    """
    n_grains = len(fractions[0])
    fig, ax = figure_unless(ax)
    ax.set_ylabel(r"$\log_{10}(f × N)$")
    ax.set_xlabel("Strain (ε)")
    # With this scaling, grains of average size (f = 1/N) map to 0.
    log_scaled = [np.log10(f * n_grains) for f in fractions]
    parts = ax.violinplot(log_scaled, positions=strains, widths=0.8)
    for body in parts["bodies"]:
        body.set_color("black")
        body.set_alpha(1)
    # Only the violin bodies are shown: the central bar is made transparent and the
    # min/max caps are hidden.
    parts["cbars"].set_alpha(0)
    for key in ("cmins", "cmaxes"):
        parts[key].set_visible(False)
    return fig, ax, parts
477
478
def show_Skemer2016_ShearStrainAngles(
    ax, studies, markers, colors, fillstyles, labels, fabric
) -> tuple:
    """Show data from `src/pydrex/data/thirdparty/Skemer2016_ShearStrainAngles.scsv`.

    Plot data from the Skemer 2016 datafile on the axis given by `ax`. Select the
    studies from which to plot the data, which must be a list of strings with exact
    matches in the `study` column in the datafile.
    Also filter the data to select only the given `fabric`
    (see `pydrex.core.MineralFabric`; only the olivine A–E fabrics are supported).

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    Returns a tuple containing:
    - the figure handle
    - the axes handle
    - the set of colors used for the data series plots
    - the Skemer 2016 dataset
    - the indices used to select data according to the "studies" and "fabric" filters,
      concatenated in the order of `studies` (empty if `studies` is empty)

    Raises `KeyError` if `fabric` is not one of the supported olivine fabrics.

    """
    fabric_map = {
        _core.MineralFabric.olivine_A: "A",
        _core.MineralFabric.olivine_B: "B",
        _core.MineralFabric.olivine_C: "C",
        _core.MineralFabric.olivine_D: "D",
        _core.MineralFabric.olivine_E: "E",
    }
    fig, ax = figure_unless(ax)
    data_Skemer2016 = _io.read_scsv(
        _io.data("thirdparty") / "Skemer2016_ShearStrainAngles.scsv"
    )
    # These do not depend on the study, so compute them once.
    study_column = np.asarray(data_Skemer2016.study)
    fabric_mask = np.asarray(data_Skemer2016.fabric) == fabric_map[fabric]

    selected = []
    for study, marker, color, fillstyle, label in zip(
        studies, markers, colors, fillstyles, labels, strict=True
    ):
        # Note: np.nonzero returns a tuple (one index array per dimension).
        study_indices = np.nonzero(np.logical_and(study_column == study, fabric_mask))[0]
        selected.append(study_indices)
        ax.plot(
            # NOTE(review): `/ 200` presumably converts shear strain in percent to
            # tensorial strain (γ/2) — confirm against the datafile.
            np.take(data_Skemer2016.shear_strain, study_indices) / 200,
            np.take(data_Skemer2016.angle, study_indices),
            marker=marker,
            fillstyle=fillstyle,
            linestyle="none",
            color=color,
            label=label,
        )
    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)

    # Previously only the last study's indices were returned (and an empty `studies`
    # raised UnboundLocalError); return the selection for all requested studies.
    indices = np.concatenate(selected) if selected else np.array([], dtype=np.intp)
    return fig, ax, colors, data_Skemer2016, indices
533
534
def spin(
    ax,
    initial_angles,
    rotation_rates,
    target_initial_angles=None,
    target_rotation_rates=None,
    labels=("target", "computed"),
    shear_axis=None,
) -> tuple:
    """Plot rotation rates of grains with known, unique initial [100] angles from X.

    If `ax` is None, a new figure and axes are created with `figure_unless`.
    The computed `rotation_rates` are drawn as unfilled markers against
    `initial_angles`. If `target_rotation_rates` is given, it is drawn as an orange
    line against `target_initial_angles`. The default labels ("target", "computed")
    of these two series can be overridden with `labels`.
    If `shear_axis` is not None, dashed vertical lines are drawn at the given x-value
    and at that value shifted by 180° (modulo 360°).

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    fig, ax = figure_unless(ax)
    ax.set_ylabel("Rotation rate (°/s)")
    ax.set_xlabel("Initial [100] angle (°)")
    ax.set_xlim((0, 360))
    ax.set_xticks(np.linspace(0, 360, 9))
    if shear_axis is not None:
        guide_style = {"color": "k", "linestyle": "--", "alpha": 0.5}
        ax.axvline(shear_axis, **guide_style)
        # Only the second line is labelled, so the legend lists the shear axis once.
        ax.axvline((shear_axis + 180) % 360, label="shear axis", **guide_style)

    colors = []
    if target_rotation_rates is not None:
        target_line = ax.plot(
            target_initial_angles,
            target_rotation_rates,
            c="tab:orange",
            label=labels[0],
        )[0]
        colors.append(target_line.get_color())
    computed = ax.scatter(
        initial_angles,
        rotation_rates,
        facecolors="none",
        edgecolors=plt.rcParams["axes.edgecolor"],
        label=labels[1],
    )
    colors.append(computed.get_edgecolors()[0])
    _utils.redraw_legend(ax)
    return fig, ax, colors
588
589
def growth(
    ax,
    initial_angles,
    fractions_diff,
    target_initial_angles=None,
    target_fractions_diff=None,
    labels=("target", "computed"),
    shear_axis=None,
) -> tuple:
    """Plot grain growth of grains with known, unique initial [100] angles from X.

    If `ax` is None, a new figure and axes are created with `figure_unless`.
    The computed `fractions_diff` are drawn as unfilled markers against
    `initial_angles`. If `target_fractions_diff` is given, it is drawn as an orange
    line against `target_initial_angles`. The default labels ("target", "computed")
    of these two series can be overridden with `labels`.
    If `shear_axis` is not None, dashed vertical lines are drawn at the given x-value
    and at that value shifted by 180° (modulo 360°).

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    fig, ax = figure_unless(ax)
    ax.set_ylabel("Grain growth rate (s⁻¹)")
    ax.set_xlabel("Initial [100] angle (°)")
    ax.set_xlim((0, 360))
    ax.set_xticks(np.linspace(0, 360, 9))
    if shear_axis is not None:
        guide_style = {"color": "k", "linestyle": "--", "alpha": 0.5}
        ax.axvline(shear_axis, **guide_style)
        # Only the second line is labelled, so the legend lists the shear axis once.
        ax.axvline((shear_axis + 180) % 360, label="shear axis", **guide_style)

    colors = []
    if target_fractions_diff is not None:
        target_line = ax.plot(
            target_initial_angles,
            target_fractions_diff,
            c="tab:orange",
            label=labels[0],
        )[0]
        colors.append(target_line.get_color())
    computed = ax.scatter(
        initial_angles,
        fractions_diff,
        facecolors="none",
        edgecolors=plt.rcParams["axes.edgecolor"],
        label=labels[1],
    )
    colors.append(computed.get_edgecolors()[0])
    _utils.redraw_legend(ax)
    return fig, ax, colors
643
644
def figure_unless(ax: plt.Axes | None) -> tuple[plt.Figure, plt.Axes]:
    """Return the figure owning `ax`, or create a new figure and axes if `ax` is None.

    New figures are created with `matplotlib.pyplot.figure()`, so they pick up the
    opinionated defaults (grid, constrained layout, high DPI) set in `rcParams` when
    this module is imported.

    Returns a tuple containing the figure handle and the axes object.

    """
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure()
        ax = fig.add_subplot()
    assert fig is not None
    return fig, ax
662
663
def figure(figscale: tuple[float, float] | None = None, **kwargs) -> plt.Figure:
    """Create new figure with a few opinionated default settings.

    The defaults (e.g. grid, constrained layout, high DPI) come from the `rcParams`
    set when this module is imported.

    - `figscale` (optional) — (width, height) factors applied to the default figure
      size; when given, it takes precedence over any `figsize` keyword argument
    - any other keyword arguments are passed to `matplotlib.pyplot.figure()`

    """
    if figscale is None:
        size = kwargs.pop("figsize", (DEFAULT_FIG_WIDTH, DEFAULT_FIG_HEIGHT))
    else:
        # An explicit `figsize` is discarded in favour of the scaled default.
        kwargs.pop("figsize", None)
        size = (DEFAULT_FIG_WIDTH * figscale[0], DEFAULT_FIG_HEIGHT * figscale[1])
    return plt.figure(figsize=size, **kwargs)
def default_tick_formatter(x, pos):
def default_tick_formatter(x, pos):
    """Format the tick value `x` divided by 1000, with one decimal place.

    This is the default tick formatter of `steady_box2d`, whose default axis label
    suffix is "(km)" (so `x` is presumably in metres — confirm for each caller).
    The tick position `pos` is accepted only to match matplotlib's function-formatter
    call signature; it is ignored.

    """
    scaled = x / 1e3
    return format(scaled, ".1f")
def polefigures( orientations, ref_axes, i_range, density=False, savefile='polefigures.png', strains=None, **kwargs):
 39def polefigures(
 40    orientations,
 41    ref_axes,
 42    i_range,
 43    density=False,
 44    savefile="polefigures.png",
 45    strains=None,
 46    **kwargs,
 47):
 48    """Plot pole figures of a series of (Nx3x3) orientation matrix stacks.
 49
 50    Produces [100], [010] and [001] pole figures for (resampled) orientations.
 51    For the argument specification, check the output of `pydrex-polefigures --help`
 52    on the command line.
 53
 54    """
 55    if len(orientations) != len(i_range):
 56        raise ValueError("mismatched length of 'orientations' and 'i_range'")
 57    if strains is not None and len(strains) != len(i_range):
 58        raise ValueError("mismatched length of 'strains'")
 59    n_orientations = len(orientations)
 60    fig = plt.figure(figsize=(n_orientations, 4), dpi=600)
 61
 62    if len(i_range) == 1:
 63        grid = fig.add_gridspec(3, n_orientations, hspace=0, wspace=0.2)
 64        first_row = 0
 65    else:
 66        grid = fig.add_gridspec(
 67            4, n_orientations, height_ratios=((1, 3, 3, 3)), hspace=0, wspace=0.2
 68        )
 69        fig_strain = fig.add_subfigure(grid[0, :])
 70        first_row = 1
 71        ax_strain = fig_strain.add_subplot(111)
 72
 73        if strains is None:
 74            fig_strain.suptitle(
 75                f"N ⋅ (max strain) / {i_range.stop}", x=0.5, y=0.85, fontsize="small"
 76            )
 77            ax_strain.set_xlim(
 78                (i_range.start - i_range.step / 2, i_range.stop - i_range.step / 2)
 79            )
 80            ax_strain.set_xticks(list(i_range))
 81        else:
 82            fig_strain.suptitle("strain (%)", x=0.5, y=0.85, fontsize="small")
 83            ax_strain.set_xticks(strains)
 84            ax_strain.set_xlim(
 85                (
 86                    strains[0] - strains[1] / 2,
 87                    strains[-1] + strains[1] / 2,
 88                )
 89            )
 90
 91        ax_strain.set_frame_on(False)
 92        ax_strain.grid(False)
 93        ax_strain.yaxis.set_visible(False)
 94        ax_strain.xaxis.set_tick_params(labelsize="x-small", length=0)
 95
 96    fig100 = fig.add_subfigure(
 97        grid[first_row, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
 98    )
 99    fig100.suptitle("[100]", fontsize="small")
100    fig010 = fig.add_subfigure(
101        grid[first_row + 1, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
102    )
103    fig010.suptitle("[010]", fontsize="small")
104    fig001 = fig.add_subfigure(
105        grid[first_row + 2, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
106    )
107    fig001.suptitle("[001]", fontsize="small")
108    for n, orientations in enumerate(orientations):
109        ax100 = fig100.add_subplot(
110            1, n_orientations, n + 1, projection="pydrex.polefigure"
111        )
112        pf100 = ax100.polefigure(
113            orientations,
114            hkl=[1, 0, 0],
115            ref_axes=ref_axes,
116            density=density,
117            density_kwargs=kwargs,
118        )
119        ax010 = fig010.add_subplot(
120            1, n_orientations, n + 1, projection="pydrex.polefigure"
121        )
122        pf010 = ax010.polefigure(
123            orientations,
124            hkl=[0, 1, 0],
125            ref_axes=ref_axes,
126            density=density,
127            density_kwargs=kwargs,
128        )
129        ax001 = fig001.add_subplot(
130            1, n_orientations, n + 1, projection="pydrex.polefigure"
131        )
132        pf001 = ax001.polefigure(
133            orientations,
134            hkl=[0, 0, 1],
135            ref_axes=ref_axes,
136            density=density,
137            density_kwargs=kwargs,
138        )
139        if density:
140            for ax, pf in zip(
141                (ax100, ax010, ax001), (pf100, pf010, pf001), strict=True
142            ):
143                cbar = fig.colorbar(
144                    pf,
145                    ax=ax,
146                    fraction=0.05,
147                    location="bottom",
148                    orientation="horizontal",
149                )
150                cbar.ax.xaxis.set_tick_params(labelsize="xx-small")
151
152    fig.savefig(_io.resolve_path(savefile))

Plot pole figures of a series of (Nx3x3) orientation matrix stacks.

Produces [100], [010] and [001] pole figures for (resampled) orientations. For the argument specification, check the output of `pydrex-polefigures --help` on the command line.

def steady_box2d( ax: matplotlib.axes._axes.Axes | None, velocity: tuple, geometry: tuple, ref_axes: str, cpo: tuple | None, colors, aspect='equal', cmap=<matplotlib.colors.ListedColormap object>, marker='.', tick_formatter=<function default_tick_formatter>, label_suffix='(km)', **kwargs) -> tuple:
def steady_box2d(
    ax: plt.Axes | None,
    velocity: tuple,
    geometry: tuple,
    ref_axes: str,
    cpo: tuple | None,
    colors,
    aspect="equal",
    cmap=cmc.batlow,
    marker=".",
    tick_formatter=default_tick_formatter,
    label_suffix="(km)",
    **kwargs,
) -> tuple:
    """Plot pathlines and steady-state velocity arrows for a 2D box domain.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `velocity` — tuple containing a velocity callable¹ and the 2D resolution of the
      velocity arrow grid, e.g. [20, 20] for 20x20 arrows over the rectangular domain;
      set the resolution to `None` to skip drawing velocity arrows
    - `geometry` — tuple containing the array of 3D pathline positions and two 2D
      coordinates (of the lower-left and upper-right domain corners)
    - `ref_axes` — two letters from {"x", "y", "z"} used to label the horizontal and
      vertical axes (these also define the projection for the 3D velocity/position)
    - `cpo` — tuple containing one array of CPO strengths and one of 3D CPO vectors;
      alternatively set this to `None` and use `marker` to only plot pathline positions
    - `colors` — monotonic, increasing values along the pathline (e.g. time or strain)
    - `aspect` — optional, see `matplotlib.axes.Axes.set_aspect`
    - `cmap` — optional custom color map for `colors`
    - `marker` — optional pathline position marker used when `cpo` is `None`
    - `tick_formatter` — optional custom tick formatter callable
    - `label_suffix` — optional suffix added to the axes labels

    ¹with signature `f(t, x)` where `t` is not used and `x` is a 3D position vector

    Additional keyword arguments are passed to the `matplotlib.axes.Axes.quiver` call
    used to plot the velocity vectors. The `width` and `zorder` keywords, if given, are
    also used for the CPO vectors, which are drawn at `zorder + 1` (default 11) so that
    they always appear above the velocity vectors.

    Returns a tuple containing:
    - the figure handle
    - the axes handle
    - the quiver collection of velocity arrows, or `None` if the velocity grid
      resolution is `None`
    - the pathline collection: a scatter collection of positions if `cpo` is `None`,
      otherwise a quiver collection of CPO vectors scaled by the CPO strengths

    """
    fig, ax = figure_unless(ax)
    ax.set_xlabel(f"{ref_axes[0]} {label_suffix}")
    ax.set_ylabel(f"{ref_axes[1]} {label_suffix}")

    get_velocity, resolution = velocity
    positions, min_coords, max_coords = geometry
    x_min, y_min = min_coords
    x_max, y_max = max_coords

    ax.set_xlim((x_min, x_max))
    ax.set_ylim((y_min, y_max))
    ax.set_aspect(aspect)
    ax.xaxis.set_major_formatter(tick_formatter)
    ax.yaxis.set_major_formatter(tick_formatter)

    # Map the (case-insensitive) axis letters to 3D vector component indices.
    _ref_axes = ref_axes.lower()
    axes_map = {"x": 0, "y": 1, "z": 2}
    horizontal = axes_map[_ref_axes[0]]
    vertical = axes_map[_ref_axes[1]]

    velocities = None
    if resolution is not None:
        x_res, y_res = resolution
        X = np.linspace(x_min, x_max, x_res)
        Y = np.linspace(y_min, y_max, y_res)
        X_grid, Y_grid = np.meshgrid(X, Y)

        U = np.zeros_like(X_grid.ravel())
        V = np.zeros_like(Y_grid.ravel())
        for i, (x, y) in enumerate(zip(X_grid.ravel(), Y_grid.ravel(), strict=True)):
            # Embed the 2D grid point in 3D (the out-of-plane component stays zero)
            # and keep only the in-plane velocity components.
            p = np.zeros(3)
            p[horizontal] = x
            p[vertical] = y
            v3d = get_velocity(np.nan, p)  # Time argument is unused (steady flow).
            U[i] = v3d[horizontal]
            V[i] = v3d[vertical]

        velocities = ax.quiver(
            X_grid,
            Y_grid,
            U.reshape(X_grid.shape),
            V.reshape(Y_grid.shape),
            pivot=kwargs.pop("pivot", "mid"),
            alpha=kwargs.pop("alpha", 0.25),
            **kwargs,
        )

    # Index of the 3D component that is dropped in the 2D projection.
    # NOTE(review): the pathline/CPO coordinates below keep the remaining components in
    # their original order (assuming `_utils.remove_dim` simply deletes `dummy_dim`),
    # whereas the velocity arrows follow the order given in `ref_axes`; confirm these
    # agree when `ref_axes` is not given in x→y→z order (e.g. "zx").
    dummy_dim = ({0, 1, 2} - set(_geo.to_indices2d(*ref_axes))).pop()
    xi_2D = np.asarray([_utils.remove_dim(p, dummy_dim) for p in positions])
    qcoll: plt.Quiver | plt.PathCollection
    if cpo is None:
        qcoll = ax.scatter(xi_2D[:, 0], xi_2D[:, 1], marker=marker, c=colors, cmap=cmap)
    else:
        cpo_strengths, cpo_vectors = cpo
        # Scale each projected CPO direction by its strength to set the bar length.
        cpo_2D = np.asarray(
            [
                s * _utils.remove_dim(v, dummy_dim)
                for s, v in zip(cpo_strengths, cpo_vectors, strict=True)
            ]
        )
        qcoll = ax.quiver(
            xi_2D[:, 0],
            xi_2D[:, 1],
            cpo_2D[:, 0],
            cpo_2D[:, 1],
            colors,
            cmap=cmap,
            pivot="mid",
            width=kwargs.pop("width", 3e-3),
            # Zero-length heads draw headless bars (CPO axes have no polarity).
            headaxislength=0,
            headlength=0,
            zorder=kwargs.pop("zorder", 10) + 1,  # Always above velocity vectors.
        )
    return fig, ax, velocities, qcoll

Plot pathlines and steady-state velocity arrows for a 2D box domain.

If ax is None, a new figure and axes are created with figure_unless.

  • velocity — tuple containing a velocity callable¹ and the 2D resolution of the velocity arrow grid, e.g. [20, 20] for 20x20 arrows over the rectangular domain
  • geometry — tuple containing the array of 3D pathline positions and two 2D coordinates (of the lower-left and upper-right domain corners)
  • ref_axes — two letters from {"x", "y", "z"} used to label the horizontal and vertical axes (these also define the projection for the 3D velocity/position)
  • cpo — tuple containing one array of CPO strengths and one of 3D CPO vectors; alternatively set this to None and use marker to only plot pathline positions
  • colors — monotonic, increasing values along the pathline (e.g. time or strain)
  • aspect — optional, see matplotlib.axes.Axes.set_aspect
  • cmap — optional custom color map for colors
  • marker — optional pathline position marker used when cpo is None
  • tick_formatter — optional custom tick formatter callable
  • label_suffix — optional suffix added to the axes labels

¹with signature f(t, x) where t is not used and x is a 3D position vector

Additional keyword arguments are passed to the matplotlib.axes.Axes.quiver call used to plot the velocity vectors.

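Returns the figure handle, the axes handle, the quiver collection of velocity arrows (None if the velocity grid resolution is None) and the pathline collection (a scatter collection of positions when cpo is None, otherwise a quiver collection of CPO vectors).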

def alignment( ax: matplotlib.axes._axes.Axes | None, strains: numpy.ndarray, angles: numpy.ndarray, markers: list[str] | tuple[str], labels: list[str] | tuple[str], err: numpy.ndarray | None = None, θ_max: int = 90, θ_fse: numpy.ndarray | None = None, colors: numpy.ndarray | None = None, cmaps=None, **kwargs) -> tuple:
def alignment(
    ax: plt.Axes | None,
    strains: np.ndarray,
    angles: np.ndarray,
    markers: list[str] | tuple[str],
    labels: list[str] | tuple[str],
    err: np.ndarray | None = None,
    θ_max: int = 90,
    θ_fse: np.ndarray | None = None,
    colors: np.ndarray | None = None,
    cmaps=None,
    **kwargs,
) -> tuple:
    """Plot `angles` (in degrees) versus `strains` on the given axis.

    Alignment angles could be either bingham averages or the a-axis in the hexagonal
    symmetry projection, measured from e.g. the shear direction. In the first case,
    they should be calculated from resampled grain orientations. Expects as many
    `markers` and `labels` as there are data series in `angles`.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `strains` — X-values, accumulated strain (tensorial) during CPO evolution, may be
      a 2D array of multiple strain series; a single series is reused for every
      series in `angles`
    - `angles` — Y-values, may be a 2D array of multiple angle series
    - `markers` — Matplotlib markers to use for the data series
    - `labels` — labels to use for the data series
    - `err` (optional) — standard errors for the `angles`, shapes must match
    - `θ_max` — maximum angle (°) to show on the plot, should not exceed 90
    - `θ_fse` (optional) — an array of angles from the long axis of the finite strain
      ellipsoid to the reference direction (e.g. shear direction), plotted against the
      strains of the last data series
    - `colors` (optional) — color coordinates for series of angles
    - `cmaps` (optional) — color maps for `colors`, one per series; if `None`, the
      default color map is used

    If `colors` is used, then angle values are colored individually within each angle
    series.

    Additional keyword arguments are passed to `matplotlib.axes.Axes.scatter` if
    `colors` is not `None`, or to `matplotlib.axes.Axes.plot` otherwise. An `alpha`
    keyword (default 0.6) applies to every series; an `edgecolor` keyword (default:
    the `axes.edgecolor` rcParam) applies to every series when `colors` is used.

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    _strains = np.atleast_2d(strains)
    _angles = np.atleast_2d(angles)
    _angles_err = None if err is None else np.atleast_2d(err)
    if _strains.shape != _angles.shape:
        # Assume strains are all the same for each series in `angles`, try np.tile().
        _strains = np.tile(_strains, (len(_angles), 1))

    fig, ax = figure_unless(ax)
    ax.set_ylabel(r"$\overline{θ}$ ∈ [0, 90]°")
    ax.set_ylim((0, θ_max))
    ax.set_xlabel("Strain (ε)")
    ax.set_xlim((np.min(strains), np.max(strains)))

    # Extract style overrides once so they apply to every series; popping them inside
    # the loop would only honour user-supplied values for the first series.
    alpha = kwargs.pop("alpha", 0.6)
    edgecolor = (
        kwargs.pop("edgecolor", plt.rcParams["axes.edgecolor"])
        if colors is not None
        else None
    )

    _colors = []
    for i, (series_strains, θ_cpo, marker, label) in enumerate(
        zip(_strains, _angles, markers, labels, strict=True)
    ):
        if colors is not None:
            ax.scatter(
                series_strains,
                θ_cpo,
                marker=marker,
                label=label,
                c=colors[i],
                cmap=None if cmaps is None else cmaps[i],
                alpha=alpha,
                edgecolor=edgecolor,
                **kwargs,
            )
            _colors.append(colors[i])
        else:
            lines = ax.plot(
                series_strains, θ_cpo, marker, alpha=alpha, label=label, **kwargs
            )
            _colors.append(lines[0].get_color())
        if _angles_err is not None:
            # NOTE(review): when `colors` is used, `_colors[i]` holds color coordinates
            # rather than a single color; confirm `fill_between` accepts it.
            ax.fill_between(
                series_strains,
                θ_cpo - _angles_err[i],
                θ_cpo + _angles_err[i],
                alpha=0.22,
                color=_colors[i],
            )

    if θ_fse is not None:
        ax.plot(_strains[-1], θ_fse, linestyle=(0, (5, 5)), alpha=0.6, label="FSE")
    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, _colors

Plot angles (in degrees) versus strains on the given axis.

Alignment angles could be either bingham averages or the a-axis in the hexagonal symmetry projection, measured from e.g. the shear direction. In the first case, they should be calculated from resampled grain orientations. Expects as many markers and labels as there are data series in angles.

If ax is None, a new figure and axes are created with figure_unless.

  • strains — X-values, accumulated strain (tensorial) during CPO evolution, may be a 2D array of multiple strain series
  • angles — Y-values, may be a 2D array of multiple angle series
  • markers — Matplotlib markers to use for the data series
  • labels — labels to use for the data series
  • err (optional) — standard errors for the angles, shapes must match
  • θ_max — maximum angle (°) to show on the plot, should not exceed 90 (the default)
  • θ_fse (optional) — an array of angles from the long axis of the finite strain ellipsoid to the reference direction (e.g. shear direction)
  • colors (optional) — color coordinates for series of angles
  • cmaps (optional) — color maps for colors

If colors and cmaps are used, then angle values are colored individually within each angle series.

Additional keyword arguments are passed to matplotlib.axes.Axes.scatter if colors is not None, or to matplotlib.axes.Axes.plot otherwise.

Returns a tuple of the figure handle, the axes handle and the set of colors used for the data series plots.

def strengths( ax: matplotlib.axes._axes.Axes | None, strains: numpy.ndarray, strengths: numpy.ndarray, ylabel: str, markers: list[str] | tuple[str], labels: list[str] | tuple[str], err: numpy.ndarray | None = None, cpo_threshold: float | None = None, colors: numpy.ndarray | None = None, cmaps=None, **kwargs):
def strengths(
    ax: plt.Axes | None,
    strains: np.ndarray,
    strengths: np.ndarray,
    ylabel: str,
    markers: list[str] | tuple[str],
    labels: list[str] | tuple[str],
    err: np.ndarray | None = None,
    cpo_threshold: float | None = None,
    colors: np.ndarray | None = None,
    cmaps=None,
    **kwargs,
) -> tuple:
    """Plot CPO `strengths` (e.g. M-indices) versus `strains` on the given axis.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `strains` — X-values, accumulated strain (tensorial) during CPO evolution, may be
      a 2D array of multiple strain series; a single series is reused for every
      series in `strengths`
    - `strengths` — Y-values, may be a 2D array of multiple strength series
    - `ylabel` — label for the Y axis, depending on chosen texture strength measure
    - `markers` — Matplotlib markers to use for the data series
    - `labels` — labels to use for the data series
    - `err` (optional) — standard errors for the `strengths`, shapes must match
    - `colors` (optional) — color coordinates for series of strengths
    - `cpo_threshold` (optional) — plot a dashed line at this threshold
    - `cmaps` (optional) — color maps for `colors`, one per series; if `None`, the
      default color map is used

    If `colors` is used, then strength values are colored individually within each
    strength series.

    Additional keyword arguments are passed to `matplotlib.axes.Axes.scatter` if
    `colors` is not `None`, or to `matplotlib.axes.Axes.plot` otherwise. An `alpha`
    keyword (default 0.6 for scatter, 0.33 for lines) applies to every series; an
    `edgecolor` keyword (default: the `axes.edgecolor` rcParam) applies to every
    series when `colors` is used.

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    _strains = np.atleast_2d(strains)
    _strengths = np.atleast_2d(strengths)
    _strengths_err = None if err is None else np.atleast_2d(err)
    if _strains.shape != _strengths.shape:
        # Assume strains are all the same for each series in `strengths`, try np.tile().
        _strains = np.tile(_strains, (len(_strengths), 1))

    fig, ax = figure_unless(ax)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Strain (ε)")
    ax.set_xlim((np.min(strains), np.max(strains)))

    if cpo_threshold is not None:
        ax.axhline(cpo_threshold, color=plt.rcParams["axes.edgecolor"], linestyle="--")

    # Extract style overrides once so they apply to every series; popping them inside
    # the loop would only honour user-supplied values for the first series.
    if colors is not None:
        alpha = kwargs.pop("alpha", 0.6)
        edgecolor = kwargs.pop("edgecolor", plt.rcParams["axes.edgecolor"])
    else:
        alpha = kwargs.pop("alpha", 0.33)
        edgecolor = None

    _colors = []
    for i, (series_strains, strength, marker, label) in enumerate(
        zip(_strains, _strengths, markers, labels, strict=True)
    ):
        if colors is not None:
            ax.scatter(
                series_strains,
                strength,
                marker=marker,
                label=label,
                c=colors[i],
                cmap=None if cmaps is None else cmaps[i],
                alpha=alpha,
                edgecolor=edgecolor,
                **kwargs,
            )
            _colors.append(colors[i])
        else:
            lines = ax.plot(
                series_strains, strength, marker, alpha=alpha, label=label, **kwargs
            )
            _colors.append(lines[0].get_color())
        if _strengths_err is not None:
            # NOTE(review): when `colors` is used, `_colors[i]` holds color coordinates
            # rather than a single color; confirm `fill_between` accepts it.
            ax.fill_between(
                series_strains,
                strength - _strengths_err[i],
                strength + _strengths_err[i],
                alpha=0.22,
                color=_colors[i],
            )

    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, _colors

Plot CPO strengths (e.g. M-indices) versus strains on the given axis.

If ax is None, a new figure and axes are created with figure_unless.

  • strains — X-values, accumulated strain (tensorial) during CPO evolution, may be a 2D array of multiple strain series
  • strengths — Y-values, may be a 2D array of multiple strength series
  • ylabel — label for the Y axis, depending on chosen texture strength measure
  • markers — Matplotlib markers to use for the data series
  • labels — labels to use for the data series
  • err (optional) — standard errors for the strengths, shapes must match
  • colors (optional) — color coordinates for series of strengths
  • cpo_threshold (optional) — plot a dashed line at this threshold
  • cmaps (optional) — color maps for colors

If colors and cmaps are used, then strength values are colored individually within each strength series.

Additional keyword arguments are passed to matplotlib.axes.Axes.scatter if colors is not None, or to matplotlib.axes.Axes.plot otherwise.

Returns a tuple of the figure handle, the axes handle and the set of colors used for the data series plots.

def grainsizes(ax, strains, fractions) -> tuple:
def grainsizes(ax, strains, fractions) -> tuple:
    """Plot grain volume `fractions` versus `strains` on the given axis.

    Each entry of `fractions` is drawn as one black violin, positioned at the matching
    value in `strains`. The plotted quantity is log10(f × N), where N is the number of
    grains in the first entry of `fractions`. If all N grains have the equal fraction
    1/N, the violin is centred on 0.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    Returns a tuple of the figure handle, the axes handle and the dictionary of artists
    returned by `matplotlib.axes.Axes.violinplot`.

    """
    grain_count = len(fractions[0])
    fig, ax = figure_unless(ax)
    ax.set_ylabel(r"$\log_{10}(f × N)$")
    ax.set_xlabel("Strain (ε)")

    scaled_sizes = [np.log10(f * grain_count) for f in fractions]
    violins = ax.violinplot(scaled_sizes, positions=strains, widths=0.8)

    # Solid black bodies; hide the central bar and the extrema ticks.
    for body in violins["bodies"]:
        body.set_color("black")
        body.set_alpha(1)
    violins["cbars"].set_alpha(0)
    for extremum in ("cmins", "cmaxes"):
        violins[extremum].set_visible(False)
    return fig, ax, violins

Plot grain volume fractions versus strains on the given axis.

If ax is None, a new figure and axes are created with figure_unless.

def show_Skemer2016_ShearStrainAngles(ax, studies, markers, colors, fillstyles, labels, fabric) -> tuple:
def show_Skemer2016_ShearStrainAngles(
    ax, studies, markers, colors, fillstyles, labels, fabric
) -> tuple:
    """Show data from `src/pydrex/data/thirdparty/Skemer2016_ShearStrainAngles.scsv`.

    Plot data from the Skemer 2016 datafile on the axis given by `ax`. Select the
    studies from which to plot the data, which must be a list of strings with exact
    matches in the `study` column in the datafile.
    Also filter the data to select only the given `fabric`
    (see `pydrex.core.MineralFabric`; only olivine A–E fabrics are supported).
    Each study is drawn as its own data series using the corresponding entries of
    `markers`, `colors`, `fillstyles` and `labels`.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    Returns a tuple containing:
    - the figure handle
    - the axes handle
    - the set of colors used for the data series plots
    - the Skemer 2016 dataset
    - the indices used to select data according to the "studies" and "fabric" filters,
      concatenated over all selected studies in the order given by `studies`
      (an empty integer array if `studies` is empty)

    """
    fabric_map = {
        _core.MineralFabric.olivine_A: "A",
        _core.MineralFabric.olivine_B: "B",
        _core.MineralFabric.olivine_C: "C",
        _core.MineralFabric.olivine_D: "D",
        _core.MineralFabric.olivine_E: "E",
    }
    fig, ax = figure_unless(ax)
    data_Skemer2016 = _io.read_scsv(
        _io.data("thirdparty") / "Skemer2016_ShearStrainAngles.scsv"
    )
    # Loop-invariant column conversions and fabric filter.
    study_column = np.asarray(data_Skemer2016.study)
    fabric_mask = np.asarray(data_Skemer2016.fabric) == fabric_map[fabric]

    selected = []
    for study, marker, color, fillstyle, label in zip(
        studies, markers, colors, fillstyles, labels, strict=True
    ):
        # Note: np.nonzero returns a tuple.
        study_indices = np.nonzero(np.logical_and(study_column == study, fabric_mask))[
            0
        ]
        ax.plot(
            # NOTE(review): presumably converts percent shear strain γ to tensorial
            # strain (γ/2) to match the other strain axes — confirm against datafile.
            np.take(data_Skemer2016.shear_strain, study_indices) / 200,
            np.take(data_Skemer2016.angle, study_indices),
            marker=marker,
            fillstyle=fillstyle,
            linestyle="none",
            color=color,
            label=label,
        )
        selected.append(study_indices)

    # Previously only the last study's indices were returned (and an empty `studies`
    # raised NameError); return the selection for all requested studies instead.
    indices = np.concatenate(selected) if selected else np.array([], dtype=np.intp)
    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, colors, data_Skemer2016, indices

Show data from src/pydrex/data/thirdparty/Skemer2016_ShearStrainAngles.scsv.

Plot data from the Skemer 2016 datafile on the axis given by ax. Select the studies from which to plot the data, which must be a list of strings with exact matches in the study column in the datafile. Also filter the data to select only the given fabric (see pydrex.core.MineralFabric).

If ax is None, a new figure and axes are created with figure_unless.

Returns a tuple containing:

  • the figure handle
  • the axes handle
  • the set of colors used for the data series plots
  • the Skemer 2016 dataset
  • the indices used to select data according to the "studies" and "fabric" filters
def spin( ax, initial_angles, rotation_rates, target_initial_angles=None, target_rotation_rates=None, labels=('target', 'computed'), shear_axis=None) -> tuple:
def spin(
    ax,
    initial_angles,
    rotation_rates,
    target_initial_angles=None,
    target_rotation_rates=None,
    labels=("target", "computed"),
    shear_axis=None,
) -> tuple:
    """Plot rotation rates of grains with known, unique initial [100] angles from X.

    If `ax` is None, a new figure and axes are created with `figure_unless`.
    The default labels ("target", "computed") can also be overridden: the first label
    is used for the target curve and the second for the computed values. The target
    curve is only drawn when `target_rotation_rates` is not None.
    If `shear_axis` is not None, dashed lines will be drawn at the given x-value and at
    that value shifted by 180° (modulo 360°).

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    fig, ax = figure_unless(ax)
    ax.set_ylabel("Rotation rate (°/s)")
    ax.set_xlabel("Initial [100] angle (°)")
    ax.set_xlim((0, 360))
    ax.set_xticks(np.linspace(0, 360, 9))

    if shear_axis is not None:
        # Only the second line carries a label, so the legend gets a single entry.
        dashed = {"color": "k", "linestyle": "--", "alpha": 0.5}
        ax.axvline(shear_axis, **dashed)
        ax.axvline((shear_axis + 180) % 360, label="shear axis", **dashed)

    series_colors = []
    if target_rotation_rates is not None:
        target_lines = ax.plot(
            target_initial_angles,
            target_rotation_rates,
            c="tab:orange",
            label=labels[0],
        )
        series_colors.append(target_lines[0].get_color())

    computed = ax.scatter(
        initial_angles,
        rotation_rates,
        facecolors="none",
        edgecolors=plt.rcParams["axes.edgecolor"],
        label=labels[1],
    )
    series_colors.append(computed.get_edgecolors()[0])
    _utils.redraw_legend(ax)
    return fig, ax, series_colors

Plot rotation rates of grains with known, unique initial [100] angles from X.

If ax is None, a new figure and axes are created with figure_unless. The default labels ("target", "computed") can also be overridden. If shear_axis is not None, dashed lines will be drawn at the given x-value and at that value shifted by 180° (modulo 360°).

Returns a tuple of the figure handle, the axes handle and the set of colors used for the data series plots.

def growth( ax, initial_angles, fractions_diff, target_initial_angles=None, target_fractions_diff=None, labels=('target', 'computed'), shear_axis=None) -> tuple:
def growth(
    ax,
    initial_angles,
    fractions_diff,
    target_initial_angles=None,
    target_fractions_diff=None,
    labels=("target", "computed"),
    shear_axis=None,
) -> tuple:
    """Plot grain growth of grains with known, unique initial [100] angles from X.

    If `ax` is None, a new figure and axes are created with `figure_unless`.
    The default labels ("target", "computed") can also be overridden: the first label
    is used for the target curve and the second for the computed values. The target
    curve is only drawn when `target_fractions_diff` is not None.
    If `shear_axis` is not None, dashed lines will be drawn at the given x-value and at
    that value shifted by 180° (modulo 360°).

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    fig, ax = figure_unless(ax)
    ax.set_ylabel("Grain growth rate (s⁻¹)")
    ax.set_xlabel("Initial [100] angle (°)")
    ax.set_xlim((0, 360))
    ax.set_xticks(np.linspace(0, 360, 9))

    if shear_axis is not None:
        # Only the second line carries a label, so the legend gets a single entry.
        dashed = {"color": "k", "linestyle": "--", "alpha": 0.5}
        ax.axvline(shear_axis, **dashed)
        ax.axvline((shear_axis + 180) % 360, label="shear axis", **dashed)

    series_colors = []
    if target_fractions_diff is not None:
        target_lines = ax.plot(
            target_initial_angles,
            target_fractions_diff,
            c="tab:orange",
            label=labels[0],
        )
        series_colors.append(target_lines[0].get_color())

    computed = ax.scatter(
        initial_angles,
        fractions_diff,
        facecolors="none",
        edgecolors=plt.rcParams["axes.edgecolor"],
        label=labels[1],
    )
    series_colors.append(computed.get_edgecolors()[0])
    _utils.redraw_legend(ax)
    return fig, ax, series_colors

Plot grain growth of grains with known, unique initial [100] angles from X.

If ax is None, a new figure and axes are created with figure_unless. The default labels ("target", "computed") can also be overridden. If shear_axis is not None, dashed lines will be drawn at the given x-value and at that value shifted by 180° (modulo 360°).

Returns a tuple of the figure handle, the axes handle and the set of colors used for the data series plots.

def figure_unless( ax: matplotlib.axes._axes.Axes | None) -> tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes]:
def figure_unless(ax: plt.Axes | None) -> tuple[plt.Figure, plt.Axes]:
    """Create figure and axes if `ax` is None, or return existing figure for `ax`.

    If `ax` is None, a new figure is created with `matplotlib.pyplot.figure` and a
    single subplot is added to it. The figure picks up the opinionated defaults (grid,
    constrained layout, high DPI) that this module sets via `rcParams` on import.

    Returns a tuple containing the figure handle and the axes object.

    Raises `ValueError` if the given `ax` is not attached to a figure.

    """
    fig: plt.Figure | None
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot()
    else:
        fig = ax.get_figure()
    # Explicit check rather than `assert`, which is stripped when running with -O.
    if fig is None:
        raise ValueError("the given axes object is not attached to a figure")
    return fig, ax

Create figure and axes if ax is None, or return existing figure for ax.

If ax is None, a new figure is created for the axes with a few opinionated default settings (grid, constrained layout, high DPI).

Returns a tuple containing the figure handle and the axes object.

def figure( figscale: tuple[float, float] | None = None, **kwargs) -> matplotlib.figure.Figure:
def figure(figscale: tuple[float, float] | None = None, **kwargs) -> plt.Figure:
    """Create new figure with a few opinionated default settings.

    (e.g. grid, constrained layout, high DPI).

    Pass `figscale` as a (width factor, height factor) tuple to scale the default
    figure width and height. When `figscale` is given it takes precedence, and any
    `figsize` keyword argument is ignored. Otherwise `figsize` (default: the module's
    default figure size) is used as-is. All remaining keyword arguments are forwarded
    to `matplotlib.pyplot.figure()`.

    """
    # NOTE: Opinionated defaults are set using rcParams at the top of this file.
    # Always remove `figsize` from kwargs so it is never passed to pyplot twice.
    requested_size = kwargs.pop("figsize", (DEFAULT_FIG_WIDTH, DEFAULT_FIG_HEIGHT))
    if figscale is None:
        size = requested_size
    else:
        size = (
            DEFAULT_FIG_WIDTH * figscale[0],
            DEFAULT_FIG_HEIGHT * figscale[1],
        )
    return plt.figure(figsize=size, **kwargs)

Create new figure with a few opinionated default settings.

(e.g. grid, constrained layout, high DPI).

The keyword argument figscale can be used to scale the figure width and height relative to the default values by passing a tuple. Any additional keyword arguments are passed to matplotlib.pyplot.figure().