
PyDRex: Visualisation functions for test outputs and examples.

  1"""> PyDRex: Visualisation functions for test outputs and examples."""
  3import numpy as np
  4from cmcrameri import cm as cmc
  5from matplotlib import projections as mproj
  6from matplotlib import pyplot as plt
  8from pydrex import axes as _axes
  9from pydrex import core as _core
 10from pydrex import geometry as _geo
 11from pydrex import io as _io
 12from pydrex import logger as _log
 13from pydrex import utils as _utils
 15# Use a non-interactive vector-graphics backend by default.
 16plt.rcParams["backend"] = "PDF"
 17# Get default figure size for easy referencing and scaling.
 18DEFAULT_FIG_WIDTH, DEFAULT_FIG_HEIGHT = plt.rcParams["figure.figsize"]
 19plt.rcParams["axes.grid"] = True
 20# Always draw grid behind everything else.
 21plt.rcParams["axes.axisbelow"] = True
 22# Always use constrained layout by default (modern version of tight layout).
 23plt.rcParams["figure.constrained_layout.use"] = True
 24# Use 300 DPI by default, NASA can keep their blurry images.
 25plt.rcParams["figure.dpi"] = 300
 26# Make sure we have the required matplotlib "projections" (really just Axes subclasses).
 27if "pydrex.polefigure" not in mproj.get_projection_names():
 28    _log.warning(
 29        "failed to find pydrex.polefigure projection; it should be registered in %s",
 30        _axes,
 31    )
 34def default_tick_formatter(x, pos):
 35    return f"{x/1e3:.1f}"
 38def polefigures(
 39    orientations,
 40    ref_axes,
 41    i_range,
 42    density=False,
 43    savefile="polefigures.png",
 44    strains=None,
 45    **kwargs,
 47    """Plot pole figures of a series of (Nx3x3) orientation matrix stacks.
 49    Produces [100], [010] and [001] pole figures for (resampled) orientations.
 50    For the argument specification, check the output of `pydrex-polefigures --help`
 51    on the command line.
 53    """
 54    if len(orientations) != len(i_range):
 55        raise ValueError("mismatched length of 'orientations' and 'i_range'")
 56    if strains is not None and len(strains) != len(i_range):
 57        raise ValueError("mismatched length of 'strains'")
 58    n_orientations = len(orientations)
 59    fig = plt.figure(figsize=(n_orientations, 4), dpi=600)
 61    if len(i_range) == 1:
 62        grid = fig.add_gridspec(3, n_orientations, hspace=0, wspace=0.2)
 63        first_row = 0
 64    else:
 65        grid = fig.add_gridspec(
 66            4, n_orientations, height_ratios=((1, 3, 3, 3)), hspace=0, wspace=0.2
 67        )
 68        fig_strain = fig.add_subfigure(grid[0, :])
 69        first_row = 1
 70        ax_strain = fig_strain.add_subplot(111)
 72        if strains is None:
 73            fig_strain.suptitle(
 74                f"N ⋅ (max strain) / {i_range.stop}", x=0.5, y=0.85, fontsize="small"
 75            )
 76            ax_strain.set_xlim(
 77                (i_range.start - i_range.step / 2, i_range.stop - i_range.step / 2)
 78            )
 79            ax_strain.set_xticks(list(i_range))
 80        else:
 81            fig_strain.suptitle("strain (%)", x=0.5, y=0.85, fontsize="small")
 82            ax_strain.set_xticks(strains)
 83            ax_strain.set_xlim(
 84                (
 85                    strains[0] - strains[1] / 2,
 86                    strains[-1] + strains[1] / 2,
 87                )
 88            )
 90        ax_strain.set_frame_on(False)
 91        ax_strain.grid(False)
 92        ax_strain.yaxis.set_visible(False)
 93        ax_strain.xaxis.set_tick_params(labelsize="x-small", length=0)
 95    fig100 = fig.add_subfigure(
 96        grid[first_row, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
 97    )
 98    fig100.suptitle("[100]", fontsize="small")
 99    fig010 = fig.add_subfigure(
100        grid[first_row + 1, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
101    )
102    fig010.suptitle("[010]", fontsize="small")
103    fig001 = fig.add_subfigure(
104        grid[first_row + 2, :], edgecolor=plt.rcParams["grid.color"], linewidth=1
105    )
106    fig001.suptitle("[001]", fontsize="small")
107    for n, orientations in enumerate(orientations):
108        ax100 = fig100.add_subplot(
109            1, n_orientations, n + 1, projection="pydrex.polefigure"
110        )
111        pf100 = ax100.polefigure(
112            orientations,
113            hkl=[1, 0, 0],
114            ref_axes=ref_axes,
115            density=density,
116            density_kwargs=kwargs,
117        )
118        ax010 = fig010.add_subplot(
119            1, n_orientations, n + 1, projection="pydrex.polefigure"
120        )
121        pf010 = ax010.polefigure(
122            orientations,
123            hkl=[0, 1, 0],
124            ref_axes=ref_axes,
125            density=density,
126            density_kwargs=kwargs,
127        )
128        ax001 = fig001.add_subplot(
129            1, n_orientations, n + 1, projection="pydrex.polefigure"
130        )
131        pf001 = ax001.polefigure(
132            orientations,
133            hkl=[0, 0, 1],
134            ref_axes=ref_axes,
135            density=density,
136            density_kwargs=kwargs,
137        )
138        if density:
139            for ax, pf in zip(
140                (ax100, ax010, ax001), (pf100, pf010, pf001), strict=True
141            ):
142                cbar = fig.colorbar(
143                    pf,
144                    ax=ax,
145                    fraction=0.05,
146                    location="bottom",
147                    orientation="horizontal",
148                )
149      "xx-small")
151    fig.savefig(_io.resolve_path(savefile))
def steady_box2d(
    ax: plt.Axes | None,
    velocity: tuple,
    geometry: tuple,
    ref_axes: str,
    cpo: tuple | None,
    colors,
    aspect="equal",
    cmap=cmc.batlow,
    marker=".",
    tick_formatter=default_tick_formatter,
    label_suffix="(km)",
    **kwargs,
) -> tuple:
    """Plot pathlines and steady-state velocity arrows for a 2D box domain.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `velocity` — tuple containing a velocity callable¹ and the 2D resolution of the
      velocity arrow grid, e.g. [20, 20] for 20x20 arrows over the rectangular domain;
      if the resolution is `None`, no velocity arrows are drawn
    - `geometry` — tuple containing the array of 3D pathline positions and two 2D
      coordinates (of the lower-left and upper-right domain corners)
    - `ref_axes` — two letters from {"x", "y", "z"} used to label the horizontal and
      vertical axes (these also define the projection for the 3D velocity/position)
    - `cpo` — tuple containing one array of CPO strengths and one of 3D CPO vectors;
      alternatively set this to `None` and use `marker` to only plot pathline positions
    - `colors` — monotonic, increasing values along the pathline (e.g. time or strain)
    - `aspect` — optional, see `matplotlib.axes.Axes.set_aspect`
    - `cmap` — optional custom color map for `colors`
    - `marker` — optional pathline position marker used when `cpo` is `None`
    - `tick_formatter` — optional custom tick formatter callable
    - `label_suffix` — optional suffix added to the axes labels

    ¹with signature `f(t, x)` where `t` is not used and `x` is a 3D position vector

    Additional keyword arguments are passed to the `matplotlib.axes.Axes.quiver` call
    used to plot the velocity vectors (`width` and `zorder` are also read for the CPO
    vectors, which are always drawn one level above `zorder`, default 10).

    Returns the figure handle, the axes handle, the quiver collection of velocity
    arrows (`None` if no arrow resolution was given) and the pathline artist: a
    scatter collection when `cpo` is `None`, otherwise a quiver of CPO vectors.

    """
    fig, ax = figure_unless(ax)
    ax.set_xlabel(f"{ref_axes[0]} {label_suffix}")
    ax.set_ylabel(f"{ref_axes[1]} {label_suffix}")

    velocity_fn, arrow_resolution = velocity
    path, lower_left, upper_right = geometry
    x_lo, y_lo = lower_left
    x_hi, y_hi = upper_right

    ax.set_xlim((x_lo, x_hi))
    ax.set_ylim((y_lo, y_hi))
    ax.set_aspect(aspect)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(tick_formatter)

    index_of = {"x": 0, "y": 1, "z": 2}
    lowered = ref_axes.lower()
    h_idx = index_of[lowered[0]]
    v_idx = index_of[lowered[1]]

    velocities = None
    if arrow_resolution is not None:
        n_x, n_y = arrow_resolution
        grid_x, grid_y = np.meshgrid(
            np.linspace(x_lo, x_hi, n_x), np.linspace(y_lo, y_hi, n_y)
        )
        flat_x = grid_x.ravel()
        flat_y = grid_y.ravel()
        u_flat = np.zeros_like(flat_x)
        v_flat = np.zeros_like(flat_y)
        # Sample the 3D velocity field on the projection plane (other coordinate 0).
        for k, (px, py) in enumerate(zip(flat_x, flat_y, strict=True)):
            point = np.zeros(3)
            point[h_idx] = px
            point[v_idx] = py
            v3d = velocity_fn(np.nan, point)
            u_flat[k] = v3d[h_idx]
            v_flat[k] = v3d[v_idx]

        quiver_style = {
            "pivot": kwargs.pop("pivot", "mid"),
            "alpha": kwargs.pop("alpha", 0.25),
        }
        velocities = ax.quiver(
            grid_x,
            grid_y,
            u_flat.reshape(grid_x.shape),
            v_flat.reshape(grid_y.shape),
            **quiver_style,
            **kwargs,
        )

    ignored_dim = ({0, 1, 2} - set(_geo.to_indices2d(*ref_axes))).pop()
    path_2d = np.asarray([_utils.remove_dim(p, ignored_dim) for p in path])
    if cpo is None:
        pathline = ax.scatter(
            path_2d[:, 0], path_2d[:, 1], marker=marker, c=colors, cmap=cmap
        )
    else:
        cpo_strengths, cpo_vectors = cpo
        scaled_vectors = np.asarray(
            [
                strength * _utils.remove_dim(vector, ignored_dim)
                for strength, vector in zip(cpo_strengths, cpo_vectors, strict=True)
            ]
        )
        pathline = ax.quiver(
            path_2d[:, 0],
            path_2d[:, 1],
            scaled_vectors[:, 0],
            scaled_vectors[:, 1],
            colors,
            cmap=cmap,
            pivot="mid",
            width=kwargs.pop("width", 3e-3),
            headaxislength=0,
            headlength=0,
            zorder=kwargs.pop("zorder", 10) + 1,  # Always above velocity vectors.
        )
    return fig, ax, velocities, pathline
def alignment(
    ax: plt.Axes | None,
    strains: np.ndarray,
    angles: np.ndarray,
    markers: list[str] | tuple[str],
    labels: list[str] | tuple[str],
    err: np.ndarray | None = None,
    θ_max: int = 90,
    θ_fse: np.ndarray | None = None,
    colors: np.ndarray | None = None,
    cmaps=None,
    **kwargs,
) -> tuple:
    """Plot `angles` (in degrees) versus `strains` on the given axis.

    Alignment angles could be either bingham averages or the a-axis in the hexagonal
    symmetry projection, measured from e.g. the shear direction. In the first case,
    they should be calculated from resampled grain orientations. Expects as many
    `markers` and `labels` as there are data series in `angles`.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `strains` — X-values, accumulated strain (tensorial) during CPO evolution, may be
      a 2D array of multiple strain series; if its shape does not match `angles`, it
      is tiled so that the same strains are used for every angle series
    - `angles` — Y-values, may be a 2D array of multiple angle series
    - `markers` — MatPlotLib markers to use for the data series
    - `labels` — labels to use for the data series
    - `err` (optional) — standard errors for the `angles`, shapes must match
    - `θ_max` — maximum angle (°) to show on the plot, should be less than 90
    - `θ_fse` (optional) — an array of angles from the long axis of the finite strain
      ellipsoid to the reference direction (e.g. shear direction), plotted against
      the strains of the last data series
    - `colors` (optional) — color coordinates for series of angles
    - `cmaps` (optional) — color maps for `colors`, one per series; when `None`, no
      color map is passed to `scatter`

    If `colors` (and optionally `cmaps`) are used, then angle values are colored
    individually within each angle series.

    Additional keyword arguments are passed to `matplotlib.axes.Axes.scatter` if
    `colors` is not `None`, or to `matplotlib.axes.Axes.plot` otherwise. The `alpha`
    keyword (default 0.6) and, for scatter plots, `edgecolor` (default
    `rcParams["axes.edgecolor"]`) are applied to every data series.

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    _strains = np.atleast_2d(strains)
    _angles = np.atleast_2d(angles)
    _angles_err = None if err is None else np.atleast_2d(err)
    if _strains.shape != _angles.shape:
        # Assume strains are all the same for each series in `angles`, try np.tile().
        _strains = np.tile(_strains, (len(_angles), 1))

    fig, ax = figure_unless(ax)
    ax.set_ylabel(r"$\overline{θ}$ ∈ [0, 90]°")
    ax.set_ylim((0, θ_max))
    ax.set_xlabel("Strain (ε)")
    ax.set_xlim((np.min(strains), np.max(strains)))

    # Pop shared style options once, outside the loop, so that user-supplied values
    # apply to every series and are never passed twice to `plot`/`scatter`.
    alpha = kwargs.pop("alpha", 0.6)
    if colors is not None:
        edgecolor = kwargs.pop("edgecolor", plt.rcParams["axes.edgecolor"])

    _colors = []
    series_strains = None
    for i, (series_strains, θ_cpo, marker, label) in enumerate(
        zip(_strains, _angles, markers, labels, strict=True)
    ):
        if colors is not None:
            ax.scatter(
                series_strains,
                θ_cpo,
                marker=marker,
                label=label,
                c=colors[i],
                cmap=None if cmaps is None else cmaps[i],
                alpha=alpha,
                edgecolor=edgecolor,
                **kwargs,
            )
            _colors.append(colors[i])
        else:
            lines = ax.plot(
                series_strains, θ_cpo, marker, alpha=alpha, label=label, **kwargs
            )
            _colors.append(lines[0].get_color())
        if _angles_err is not None:
            ax.fill_between(
                series_strains,
                θ_cpo - _angles_err[i],
                θ_cpo + _angles_err[i],
                alpha=0.22,
                color=_colors[i],
            )

    if θ_fse is not None:
        # Use the strains of the last plotted series (identical for all series when
        # a single strain series was given); fall back to the raw input otherwise.
        fse_strains = strains if series_strains is None else series_strains
        ax.plot(fse_strains, θ_fse, linestyle=(0, (5, 5)), alpha=0.6, label="FSE")
    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, _colors
def strengths(
    ax: plt.Axes | None,
    strains: np.ndarray,
    strengths: np.ndarray,
    ylabel: str,
    markers: list[str] | tuple[str],
    labels: list[str] | tuple[str],
    err: np.ndarray | None = None,
    cpo_threshold: float | None = None,
    colors: np.ndarray | None = None,
    cmaps=None,
    **kwargs,
) -> tuple:
    """Plot CPO `strengths` (e.g. M-indices) versus `strains` on the given axis.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    - `strains` — X-values, accumulated strain (tensorial) during CPO evolution, may be
      a 2D array of multiple strain series; if its shape does not match `strengths`,
      it is tiled so that the same strains are used for every strength series
    - `strengths` — Y-values, may be a 2D array of multiple strength series
    - `ylabel` — label for the Y axis, depending on chosen texture strength measure
    - `markers` — MatPlotLib markers to use for the data series
    - `labels` — labels to use for the data series
    - `err` (optional) — standard errors for the `strengths`, shapes must match
    - `colors` (optional) — color coordinates for series of strengths
    - `cpo_threshold` (optional) — plot a dashed line at this threshold
    - `cmaps` (optional) — color maps for `colors`, one per series; when `None`, no
      color map is passed to `scatter`

    If `colors` (and optionally `cmaps`) are used, then strength values are colored
    individually within each strength series.

    Additional keyword arguments are passed to `matplotlib.axes.Axes.scatter` if
    `colors` is not `None`, or to `matplotlib.axes.Axes.plot` otherwise. The `alpha`
    keyword (default 0.6 for scatter, 0.33 for plot) and, for scatter plots,
    `edgecolor` (default `rcParams["axes.edgecolor"]`) apply to every data series.

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    _strains = np.atleast_2d(strains)
    _strengths = np.atleast_2d(strengths)
    _strengths_err = None if err is None else np.atleast_2d(err)
    if _strains.shape != _strengths.shape:
        # Assume strains are all the same for each series in `strengths`, try np.tile().
        _strains = np.tile(_strains, (len(_strengths), 1))

    fig, ax = figure_unless(ax)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Strain (ε)")
    ax.set_xlim((np.min(strains), np.max(strains)))

    if cpo_threshold is not None:
        ax.axhline(cpo_threshold, color=plt.rcParams["axes.edgecolor"], linestyle="--")

    # Pop shared style options once, outside the loop, so that user-supplied values
    # apply to every series and are never passed twice to `plot`/`scatter`.
    if colors is not None:
        alpha = kwargs.pop("alpha", 0.6)
        edgecolor = kwargs.pop("edgecolor", plt.rcParams["axes.edgecolor"])
    else:
        alpha = kwargs.pop("alpha", 0.33)

    _colors = []
    for i, (series_strains, strength, marker, label) in enumerate(
        zip(_strains, _strengths, markers, labels, strict=True)
    ):
        if colors is not None:
            ax.scatter(
                series_strains,
                strength,
                marker=marker,
                label=label,
                c=colors[i],
                cmap=None if cmaps is None else cmaps[i],
                alpha=alpha,
                edgecolor=edgecolor,
                **kwargs,
            )
            _colors.append(colors[i])
        else:
            lines = ax.plot(
                series_strains, strength, marker, alpha=alpha, label=label, **kwargs
            )
            _colors.append(lines[0].get_color())
        if _strengths_err is not None:
            ax.fill_between(
                series_strains,
                strength - _strengths_err[i],
                strength + _strengths_err[i],
                alpha=0.22,
                color=_colors[i],
            )

    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, _colors
def grainsizes(ax, strains, fractions) -> tuple:
    """Plot grain volume `fractions` versus `strains` on the given axis.

    Draws one violin per entry of `strains`, showing the distribution of
    log₁₀(f × N) for the corresponding series of grain volume fractions f, where N
    is the length of the first series in `fractions`. Assuming the fractions sum to
    1, a grain with the mean volume fraction (1/N) maps to 0.

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    Returns a tuple of the figure handle, the axes handle and the dictionary of
    artists returned by `matplotlib.axes.Axes.violinplot`.

    """
    n_grains = len(fractions[0])
    fig, ax = figure_unless(ax)
    ax.set_ylabel(r"$\log_{10}(f × N)$")
    ax.set_xlabel("Strain (ε)")
    # Convert each series to an array first: multiplying a plain Python list by an
    # integer would repeat the list instead of scaling its values.
    log_scaled = [np.log10(np.asarray(f) * n_grains) for f in fractions]
    parts = ax.violinplot(log_scaled, positions=strains, widths=0.8)
    # Solid black violin bodies; hide the central bar and the extrema markers.
    for part in parts["bodies"]:
        part.set_color("black")
        part.set_alpha(1)
    parts["cbars"].set_alpha(0)
    parts["cmins"].set_visible(False)
    parts["cmaxes"].set_visible(False)
    return fig, ax, parts
def show_Skemer2016_ShearStrainAngles(
    ax, studies, markers, colors, fillstyles, labels, fabric
) -> tuple:
    """Show data from `src/pydrex/data/thirdparty/Skemer2016_ShearStrainAngles.scsv`.

    Plot data from the Skemer 2016 datafile on the axis given by `ax`. Select the
    studies from which to plot the data, which must be a list of strings with exact
    matches in the `study` column in the datafile.
    Also filter the data to select only the given `fabric`
    (see `pydrex.core.MineralFabric`); only olivine A–E fabrics are supported
    (others raise `KeyError`).

    If `ax` is None, a new figure and axes are created with `figure_unless`.

    Returns a tuple containing:
    - the figure handle
    - the axes handle
    - the set of colors used for the data series plots
    - the Skemer 2016 dataset
    - the indices selected by the "studies" and "fabric" filters for the last entry
      of `studies` (an empty index array if `studies` is empty)

    """
    fabric_map = {
        _core.MineralFabric.olivine_A: "A",
        _core.MineralFabric.olivine_B: "B",
        _core.MineralFabric.olivine_C: "C",
        _core.MineralFabric.olivine_D: "D",
        _core.MineralFabric.olivine_E: "E",
    }
    fig, ax = figure_unless(ax)
    data_Skemer2016 = _io.read_scsv(
        _io.data("thirdparty") / "Skemer2016_ShearStrainAngles.scsv"
    )
    # These filter inputs are the same for every study, so compute them only once.
    study_column = np.asarray(data_Skemer2016.study)
    fabric_mask = np.asarray(data_Skemer2016.fabric) == fabric_map[fabric]
    # Stays empty (instead of being unbound at the return) if `studies` is empty.
    indices = np.array([], dtype=np.intp)
    for study, marker, color, fillstyle, label in zip(
        studies, markers, colors, fillstyles, labels, strict=True
    ):
        # Note: np.nonzero returns a tuple.
        indices = np.nonzero(np.logical_and(study_column == study, fabric_mask))[0]
        ax.plot(
            # NOTE(review): presumably converts shear strain in percent to tensorial
            # strain (γ / 2 / 100) — confirm against the datafile's units.
            np.take(data_Skemer2016.shear_strain, indices) / 200,
            np.take(data_Skemer2016.angle, indices),
            marker=marker,
            fillstyle=fillstyle,
            linestyle="none",
            color=color,
            label=label,
        )
    if not all(b is None for b in labels):
        _utils.redraw_legend(ax)
    return fig, ax, colors, data_Skemer2016, indices
def spin(
    ax,
    initial_angles,
    rotation_rates,
    target_initial_angles=None,
    target_rotation_rates=None,
    labels=("target", "computed"),
    shear_axis=None,
) -> tuple:
    """Plot rotation rates of grains with known, unique initial [100] angles from X.

    If `ax` is None, a new figure and axes are created with `figure_unless`.
    The computed rates are drawn as hollow markers; if `target_rotation_rates` is
    given, the target curve (against `target_initial_angles`) is drawn as an orange
    line. The default labels ("target", "computed") can also be overridden.
    If `shear_axis` is not None, a dashed line will be drawn at the given x-value
    (and its reflection around 180°).

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    fig, ax = figure_unless(ax)
    ax.set_ylabel("Rotation rate (°/s)")
    ax.set_xlabel("Initial [100] angle (°)")
    ax.set_xlim((0, 360))
    ax.set_xticks(np.linspace(0, 360, 9))

    if shear_axis is not None:
        # Only the reflected line carries a label, giving one legend entry.
        guide_style = {"color": "k", "linestyle": "--", "alpha": 0.5}
        ax.axvline(shear_axis, **guide_style)
        ax.axvline((shear_axis + 180) % 360, **guide_style, label="shear axis")

    series_colors = []
    if target_rotation_rates is not None:
        target_line = ax.plot(
            target_initial_angles,
            target_rotation_rates,
            c="tab:orange",
            label=labels[0],
        )[0]
        series_colors.append(target_line.get_color())

    computed = ax.scatter(
        initial_angles,
        rotation_rates,
        facecolors="none",
        edgecolors=plt.rcParams["axes.edgecolor"],
        label=labels[1],
    )
    series_colors.append(computed.get_edgecolors()[0])
    _utils.redraw_legend(ax)
    return fig, ax, series_colors
def growth(
    ax,
    initial_angles,
    fractions_diff,
    target_initial_angles=None,
    target_fractions_diff=None,
    labels=("target", "computed"),
    shear_axis=None,
) -> tuple:
    """Plot grain growth of grains with known, unique initial [100] angles from X.

    If `ax` is None, a new figure and axes are created with `figure_unless`.
    The computed growth rates are drawn as hollow markers; if
    `target_fractions_diff` is given, the target curve (against
    `target_initial_angles`) is drawn as an orange line.
    The default labels ("target", "computed") can also be overridden.
    If `shear_axis` is not None, a dashed line will be drawn at the given x-value
    (and its reflection around 180°).

    Returns a tuple of the figure handle, the axes handle and the set of colors used for
    the data series plots.

    """
    fig, ax = figure_unless(ax)
    ax.set_ylabel("Grain growth rate (s⁻¹)")
    ax.set_xlabel("Initial [100] angle (°)")
    ax.set_xlim((0, 360))
    ax.set_xticks(np.linspace(0, 360, 9))

    if shear_axis is not None:
        # Only the reflected line carries a label, giving one legend entry.
        guide_style = {"color": "k", "linestyle": "--", "alpha": 0.5}
        ax.axvline(shear_axis, **guide_style)
        ax.axvline((shear_axis + 180) % 360, **guide_style, label="shear axis")

    series_colors = []
    if target_fractions_diff is not None:
        target_line = ax.plot(
            target_initial_angles,
            target_fractions_diff,
            c="tab:orange",
            label=labels[0],
        )[0]
        series_colors.append(target_line.get_color())

    computed = ax.scatter(
        initial_angles,
        fractions_diff,
        facecolors="none",
        edgecolors=plt.rcParams["axes.edgecolor"],
        label=labels[1],
    )
    series_colors.append(computed.get_edgecolors()[0])
    _utils.redraw_legend(ax)
    return fig, ax, series_colors
def figure_unless(ax: plt.Axes | None) -> tuple[plt.Figure, plt.Axes]:
    """Return `(figure, axes)`, creating both when `ax` is None.

    When `ax` is given, its parent figure is looked up and returned alongside it.
    Otherwise a new `matplotlib.pyplot.figure()` with a single subplot is created;
    its opinionated defaults (grid, constrained layout, high DPI) come from the
    rcParams set when this module is imported.

    """
    if ax is not None:
        parent = ax.get_figure()
        # Narrow the Optional return type of `get_figure` for type checkers.
        assert parent is not None
        return parent, ax
    new_fig = plt.figure()
    return new_fig, new_fig.add_subplot()
def figure(figscale: tuple[float, float] | None = None, **kwargs) -> plt.Figure:
    """Create new figure with a few opinionated default settings.

    (e.g. grid, constrained layout, high DPI).

    The keyword argument `figscale` can be used to scale the figure width and height
    relative to the default values by passing a tuple; when it is given, any
    explicit `figsize` keyword argument is ignored. Any additional keyword arguments
    are passed to `matplotlib.pyplot.figure()`.

    """
    # NOTE: Opinionated defaults are set using rcParams at the top of this file.
    if figscale is not None:
        # `figsize` must still be removed so it is not forwarded to pyplot.
        kwargs.pop("figsize", None)
        size = (DEFAULT_FIG_WIDTH * figscale[0], DEFAULT_FIG_HEIGHT * figscale[1])
    else:
        size = kwargs.pop("figsize", (DEFAULT_FIG_WIDTH, DEFAULT_FIG_HEIGHT))
    return plt.figure(figsize=size, **kwargs)
