  1"""> PyDRex: Entry points and argument handling for command line tools.
  3All CLI handlers should be registered in the `CLI_HANDLERS` namedtuple,
  4which ensures that they will be installed as executable scripts alongside the package.
  8import argparse
  9import os
 10from collections import namedtuple
 11from zipfile import ZipFile
 13from pydrex import exceptions as _err
 14from pydrex import io as _io
 15from pydrex import logger as _log
 16from pydrex import minerals as _minerals
 17from pydrex import stats as _stats
 18from pydrex import visualisation as _vis
 21class CliTool:
 22    """Base class for CLI tools defining the required interface."""
 24    def __call__(self):
 25        return NotImplementedError
 27    def _get_args(self) -> argparse.Namespace | type[NotImplementedError]:
 28        return NotImplementedError
 31class MeshGenerator(CliTool):
 32    """PyDRex script to generate various simple meshes.
 34    Only rectangular (2D) meshes are currently supported. The RESOLUTION must be a comma
 35    delimited set of directives of the form `<LOC>:<RES>` where `<LOC>` is a location
 36    specifier, i.e. either "G" (global) or a compas direction like "N", "S", "NE", etc.,
 37    and `<RES>` is a floating point value to be set as the resolution at that location.
 39    """
 41    def __call__(self):
 42        try:  # This one is dangerous, especially in CI.
 43            from pydrex import mesh as _mesh
 44        except ImportError:
 45            raise _err.MissingDependencyError(
 46                "missing optional meshing dependencies."
 47                + " Have you installed the package with 'pip install pydrex[mesh]'?"
 48            )
 50        args = self._get_args()
 52        if is None:
 53            center = (0, 0)
 54        else:
 55            center = [float(s) for s in",")]
 56            assert len(center) == 2
 58        if args.custom_points is not None:
 59            _custom_points = [
 60                [*map(float, point.split(":"))]
 61                for point in args.custom_points.split(",")
 62            ]
 63            # Extract the insertion indices and parse into tuple.
 64            custom_indices = [int(point[0]) for point in _custom_points]
 65            custom_points = (custom_indices, [point[1:] for point in _custom_points])
 67        if args.kind == "rectangle":
 68            width, height = map(float, args.size.split(","))
 69            _loc_map = {
 70                "G": "global",
 71                "N": "north",
 72                "S": "south",
 73                "E": "east",
 74                "W": "west",
 75                "NE": "north-east",
 76                "NW": "north-west",
 77                "SE": "south-east",
 78                "SW": "south-west",
 79            }
 80            try:
 81                resolution = {
 82                    _loc_map[k]: float(v)
 83                    for k, v in map(lambda s: s.split(":"), args.resolution.split(","))
 84                }
 85            except KeyError:
 86                raise KeyError(
 87                    "invalid or unsupported location specified in resolution directive"
 88                ) from None
 89            except ValueError:
 90                raise ValueError(
 91                    "invalid resolution value. The format should be '<LOC1>:<RES1>,<LOC2>:<RES2>,...'"
 92                ) from None
 93            _mesh.rectangle(
 94                args.output[:-4],
 95                (args.ref_axes[0], args.ref_axes[1]),
 96                center,
 97                width,
 98                height,
 99                resolution,
100                custom_constraints=custom_points,
101            )
103    def _get_args(self) -> argparse.Namespace:
104        assert self.__doc__ is not None, f"missing docstring for {self}"
105        description, epilog = self.__doc__.split(os.linesep + os.linesep, 1)
106        parser = argparse.ArgumentParser(description=description, epilog=epilog)
107        parser.add_argument("size", help="width,height[,depth] of the mesh")
108        parser.add_argument(
109            "-r",
110            "--resolution",
111            help="resolution for the mesh (edge length hint(s) for gmsh)",
112            required=True,
113        )
114        parser.add_argument("output", help="output file (.msh)")
115        parser.add_argument(
116            "-c",
117            "--center",
118            help="center of the mesh as 2 or 3 comma-separated coordinates. default: (0, 0[, 0])",
119            default=None,
120        )
121        parser.add_argument(
122            "-a",
123            "--ref-axes",
124            help=(
125                "two letters from {'x', 'y', 'z'} that specify"
126                + " the horizontal and vertical axes of the mesh"
127            ),
128            default="xz",
129        )
130        parser.add_argument(
131            "-k", "--kind", help="kind of mesh, e.g. 'rectangle'", default="rectangle"
132        )
133        parser.add_argument(
134            "-p",
135            "--custom-points",
136            help="comma-separated custom point constraints (in the format index:x1:x2[:x3]:resolution)",
137            default=None,
138        )
139        return parser.parse_args()
class H5partExtractor(CliTool):
    """PyDRex script to extract raw CPO data from Fluidity .h5part files.

    Fluidity saves data stored on model `particles` to an `.h5part` file.
    This script converts that file to canonical serialisation formats:
    - a `.npz` file containing the raw CPO orientations and (surrogate) grain sizes
    - an `.scsv` file containing the pathline positions and accumulated strain

    It is assumed that CPO data is stored in keys called 'CPO_<N>' in the .h5part
    data, where `<N>` is an integer in the range 1—`ngrains`. The accumulated strain is
    read from the attribute `CPO_<S>` where S=`ngrains`+1. Particle positions are read
    from the attributes `x`, `y`, and `z`.

    At the moment, dynamic changes in fabric or phase are not supported.

    """

    def __call__(self):
        """Parse the command line and delegate the extraction to `pydrex.io`."""
        args = self._get_args()
        _io.extract_h5part(
            args.input, args.phase, args.fabric, args.ngrains, args.output
        )

    def _get_args(self) -> argparse.Namespace:
        """Build the argument parser from the class docstring and parse `sys.argv`."""
        assert self.__doc__ is not None, f"missing docstring for {self}"
        # Docstrings always use '\n' line endings (source newlines are normalised),
        # so split on that rather than `os.linesep`, which is '\r\n' on Windows.
        description, epilog = self.__doc__.split("\n\n", 1)
        parser = argparse.ArgumentParser(description=description, epilog=epilog)
        parser.add_argument("input", help="input file (.h5part)")
        parser.add_argument(
            "-p",
            "--phase",
            # Parse as int like `--fabric`; previously user input stayed a string
            # while the default was the integer 0.
            type=int,
            help="type of `pydrex.MineralPhase` (as an ordinal number); 0 by default",
            default=0,
        )
        parser.add_argument(
            "-f",
            "--fabric",
            type=int,
            help="type of `pydrex.MineralFabric` (as an ordinal number); 0 by default",
            default=0,
        )
        parser.add_argument(
            "-n",
            "--ngrains",
            help="number of grains used in the Fluidity simulation",
            type=int,
            required=True,
        )
        parser.add_argument(
            "-o",
            "--output",
            help="filename for the output NPZ file (stem also used for the .scsv)",
            required=True,
        )
        return parser.parse_args()
class NPZFileInspector(CliTool):
    """PyDRex script to show information about serialized CPO data.

    Lists the keys that should be used for the `postfix` in `pydrex.Mineral.load` and
    `pydrex.Mineral.from_file`.

    """

    # Key prefixes recognised by this inspector; any other key is reported as unknown.
    _KNOWN_PREFIXES = ("meta", "fractions", "orientations")

    def __call__(self):
        """Print every key of the input NPZ archive, warning about unknown keys."""
        args = self._get_args()
        # An NPZ file is a zip archive, so its member names are the stored keys.
        with ZipFile(args.input) as npz:
            names = npz.namelist()
            print("NPZ file with keys:")
            for name in names:
                if not name.startswith(self._KNOWN_PREFIXES):
                    _log.warning(f"found unknown NPZ key '{name}' in '{args.input}'")
                print(f" - {name}")

    def _get_args(self) -> argparse.Namespace:
        """Build the argument parser from the class docstring and parse `sys.argv`."""
        assert self.__doc__ is not None, f"missing docstring for {self}"
        # Docstrings always use '\n' line endings (source newlines are normalised),
        # so split on that rather than `os.linesep`, which is '\r\n' on Windows.
        description, epilog = self.__doc__.split("\n\n", 1)
        parser = argparse.ArgumentParser(description=description, epilog=epilog)
        parser.add_argument("input", help="input file (.npz)")
        return parser.parse_args()
class PoleFigureVisualiser(CliTool):
    """PyDRex script to plot pole figures of serialized CPO data.

    Produces [100], [010] and [001] pole figures for serialized `pydrex.Mineral`s.
    If the range of indices is not specified,
    a maximum of 25 of each pole figure will be produced by default.

    """

    @staticmethod
    def _parse_range(spec: str | None) -> range | None:
        """Parse a stop-inclusive `start:stop:step` string into a `range`.

        Returns `None` when `spec` is `None`. Raises `ValueError` if `spec` is not
        exactly three colon-separated integers or if `step` is not positive.

        """
        if spec is None:
            return None
        start, stop, step = (int(s) for s in spec.split(":"))
        if step <= 0:
            raise ValueError(f"range step must be a positive integer, not {step}")
        # Make command line start:stop:step stop-inclusive, it's more intuitive.
        # Using `stop + 1` (not `stop + step`) avoids overshooting when `stop`
        # is not aligned with `step`, e.g. 0:9:5 -> indices 0, 5 (not 0, 5, 10).
        return range(start, stop + 1, step)

    def __call__(self):
        """Parse the command line, load the mineral and plot its pole figures.

        Argument, value and PyDRex errors are logged instead of propagated.

        """
        try:
            args = self._get_args()
            i_range = self._parse_range(args.range)

            density_kwargs = {"kernel": args.kernel}
            if args.smoothing is not None:
                density_kwargs["σ"] = args.smoothing

            mineral = _minerals.Mineral.from_file(args.input, postfix=args.postfix)
            n_steps = len(mineral.orientations)
            if i_range is None:
                i_range = range(0, n_steps)
                if len(i_range) > 25:
                    _log.warning(
                        "truncating to 25 timesteps (out of %s total)", len(i_range)
                    )
                    i_range = range(0, 25)
            else:
                # Clamp the requested indices to the available data (range slicing
                # clips like list slicing), so that `i_range` and the sliced
                # orientations below always have matching lengths.
                i_range = range(n_steps)[i_range.start : i_range.stop : i_range.step]
                if len(i_range) == 0:
                    raise ValueError(
                        f"no strain indices in the requested range (found {n_steps})"
                    )

            selection = slice(i_range.start, i_range.stop, i_range.step)
            orientations_resampled, _ = _stats.resample_orientations(
                mineral.orientations[selection], mineral.fractions[selection]
            )
            if args.scsv is None:
                strains = None
            else:
                strains = _io.read_scsv(args.scsv).strain[selection]
            _vis.polefigures(
                orientations_resampled,
                ref_axes=args.ref_axes,
                i_range=i_range,
                density=args.density,
                savefile=args.out,
                strains=strains,
                **density_kwargs,
            )
        except (argparse.ArgumentError, ValueError, _err.Error) as e:
            _log.error(str(e))

    def _get_args(self) -> argparse.Namespace:
        """Build the argument parser from the class docstring and parse `sys.argv`."""
        assert self.__doc__ is not None, f"missing docstring for {self}"
        # Docstrings always use '\n' line endings (source newlines are normalised),
        # so split on that rather than `os.linesep`, which is '\r\n' on Windows.
        description, epilog = self.__doc__.split("\n\n", 1)
        parser = argparse.ArgumentParser(description=description, epilog=epilog)
        parser.add_argument("input", help="input file (.npz)")
        parser.add_argument(
            "-r",
            "--range",
            help="range of strain indices to be plotted, in the format start:stop:step (stop is inclusive)",
            default=None,
        )
        parser.add_argument(
            "-f",
            "--scsv",
            help=(
                "path to SCSV file with a column named 'strain'"
                + " that lists shear strain percentages for each strain index"
            ),
            default=None,
        )
        parser.add_argument(
            "-p",
            "--postfix",
            help=(
                "postfix of the mineral to load,"
                + " required if the input file contains data for multiple minerals"
            ),
            default=None,
        )
        parser.add_argument(
            "-d",
            "--density",
            help="toggle contouring of pole figures using point density estimation",
            default=False,
            action="store_true",
        )
        parser.add_argument(
            "-k",
            "--kernel",
            help=(
                "kernel function for point density estimation, one of:"
                + f" {list(_stats.SPHERICAL_COUNTING_KERNELS.keys())}"
            ),
            default="linear_inverse_kamb",
        )
        parser.add_argument(
            "-s",
            "--smoothing",
            help="smoothing parameter for Kamb type density estimation kernels",
            default=None,
            type=float,
            metavar="σ",
        )
        parser.add_argument(
            "-a",
            "--ref-axes",
            help=(
                "two letters from {'x', 'y', 'z'} that specify"
                + " the horizontal and vertical axes of the pole figures"
            ),
            default="xz",
        )
        parser.add_argument(
            "-o",
            "--out",
            help="name of the output file, with either .png or .pdf extension",
            default="polefigures.png",
        )
        return parser.parse_args()
354# These are not the final names of the executables (those are set in pyproject.toml).
355_CLI_HANDLERS = namedtuple(
356    "_CLI_HANDLERS",
357    (
358        "pole_figure_visualiser",
359        "npz_file_inspector",
360        "mesh_generator",
361        "h5part_extractor",
362    ),
365    pole_figure_visualiser=PoleFigureVisualiser(),
366    npz_file_inspector=NPZFileInspector(),
367    mesh_generator=MeshGenerator(),
368    h5part_extractor=H5partExtractor(),
