
Diagnostics

diagnostics

This module provides classes that simplify the computation of diagnostics typically encountered in geodynamical simulations. Users instantiate a diagnostics class with the relevant parameters and call its methods to compute the associated diagnostics.

FunctionContext(quad_degree, func)

Hold objects that can be derived from a Firedrake Function.

This class gathers references to objects that can be obtained from a Firedrake Function and computes quantities derived from those objects that remain constant for the duration of a simulation. The stored objects and quantities are: the mesh, the function space, the dx and ds measures, the FacetNormal of the mesh (as the .normal attribute), and the volume of the domain.

Typical usage example:

function_contexts[F] = FunctionContext(quad_degree, F)

Parameters:

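| Name | Type | Description | Default |
|---|---|---|---|
| quad_degree | int | Quadrature degree to use when approximating integrands involving func | required |
| func | Function | Firedrake function from which the stored objects are derived | required |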
Source code in g-adopt/gadopt/diagnostics.py
def __init__(self, quad_degree: int, func: fd.Function):
    self._function = func
    self._quad_degree = quad_degree

function cached property

The function associated with the instance

mesh cached property

The mesh on which the function has been defined

function_space cached property

The function space on which the function has been defined

dx cached property

The volume integration measure defined by the mesh and quad_degree passed when creating this instance

ds cached property

The surface integration measure defined by the mesh and quad_degree passed when creating this instance

normal cached property

The facet normal of the mesh belonging to this instance

volume cached property

The volume of the mesh belonging to this instance

boundary_ids cached property

The boundary IDs of the mesh associated with this instance
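The properties listed above are evaluated on first access and then reused. This is safe because they depend only on the mesh, the function space and the quadrature degree. The "cached property" labels suggest `functools.cached_property`, but that is an assumption because the property bodies are not shown here. The minimal sketch below illustrates the pattern. `DemoContext` and its `cell_volumes` argument are hypothetical; the real class derives the volume from the mesh.

```python
from functools import cached_property


class DemoContext:
    """Hypothetical stand-in for FunctionContext: derived data is computed once."""

    def __init__(self, cell_volumes):
        self._cell_volumes = cell_volumes
        self.volume_evaluations = 0  # counts how often the "expensive" step runs

    @cached_property
    def volume(self):
        # Stands in for an expensive assembly; runs only on first access
        self.volume_evaluations += 1
        return sum(self._cell_volumes)


ctx = DemoContext([0.25, 0.25, 0.5])
print(ctx.volume, ctx.volume)  # 1.0 1.0
print(ctx.volume_evaluations)  # 1: the second access was served from the cache
```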

check_boundary_id(boundary_id) cached

Check whether a boundary ID, or a sequence of boundary IDs, is valid, raising a KeyError if it is not

Source code in g-adopt/gadopt/diagnostics.py
@cache
def check_boundary_id(self, boundary_id: Sequence[int | str] | int | str) -> None:
    """Check if a boundary id or tuple of boundary ids is valid"""
    # strings are Sequences, so have to handle this first otherwise this function
    # searches for 't' 'o' 'p' in ( 'top', 'bottom' )
    if isinstance(boundary_id, str):
        if boundary_id not in self.boundary_ids:
            raise KeyError("Invalid boundary ID for function")
    elif isinstance(boundary_id, Sequence):
        if not all(id in self.boundary_ids for id in boundary_id):
            raise KeyError("Invalid boundary ID for function")
    else:
        if boundary_id not in self.boundary_ids:
            raise KeyError("Invalid boundary ID for function")
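The string special case is needed because a `str` is itself a `Sequence`. Without it, `"top"` would be checked character by character. The hypothetical free function below reproduces the logic outside the class so the pitfall is visible.

The sketch also shows a consequence of `@cache`, assuming it is `functools.cache`: cached arguments must be hashable. A sequence of boundary IDs should therefore be passed as a tuple rather than a list.

```python
from collections.abc import Sequence
from functools import cache


def check_boundary_id(boundary_ids, boundary_id):
    """Hypothetical free-function version of FunctionContext.check_boundary_id."""
    # A str is itself a Sequence, so handle it first; otherwise "top" would be
    # checked character by character ("t", "o", "p")
    if isinstance(boundary_id, str):
        valid = boundary_id in boundary_ids
    elif isinstance(boundary_id, Sequence):
        valid = all(b in boundary_ids for b in boundary_id)
    else:
        valid = boundary_id in boundary_ids
    if not valid:
        raise KeyError("Invalid boundary ID for function")


ids = ("top", "bottom")
check_boundary_id(ids, "top")              # accepted
check_boundary_id(ids, ("top", "bottom"))  # accepted
print(isinstance("top", Sequence))         # True: why the str branch comes first
print(all(c in ids for c in "top"))        # False: what the Sequence branch would do


@cache
def cached_lookup(boundary_id):
    # functools.cache hashes its arguments to build the cache key
    return boundary_id


print(cached_lookup((1, 2)))  # (1, 2): tuples are hashable
try:
    cached_lookup([1, 2])
except TypeError as err:
    print("TypeError:", err)  # lists are not hashable
```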

surface_area(boundary_id) cached

The surface area of the mesh boundary identified by boundary_id

Source code in g-adopt/gadopt/diagnostics.py
@cache
def surface_area(self, boundary_id: Sequence[int | str] | int | str):
    """The surface area of the mesh on the boundary belonging to boundary_id"""
    self.check_boundary_id(boundary_id)
    return get_volume(self.ds(boundary_id))

get_boundary_nodes(boundary_id) cached

Return the list of nodes on the boundary owned by this process

Creates a DirichletBC object, then uses its .nodes attribute to obtain the indices of the nodes that lie on the given boundary of the domain of the function associated with this FunctionContext instance. The dof_dset.size attribute of the FunctionSpace, which counts the nodes owned by this process, is used to exclude nodes in the halo region of the domain.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| boundary_id | Sequence[int \| str] \| int \| str | ID, or sequence of IDs, of the domain boundary | required |

Returns:

| Type | Description |
|---|---|
| list[int] | List of integers corresponding to nodes on the boundary identified by boundary_id |

Source code in g-adopt/gadopt/diagnostics.py
@cache
def get_boundary_nodes(
    self, boundary_id: Sequence[int | str] | int | str
) -> list[int]:
    """Return the list of nodes on the boundary owned by this process

    Creates a `DirichletBC` object, then uses the `.nodes` attribute for that
    object to provide a list of indices that reside on the boundary of the domain
    of the function associated with this `FunctionContext` instance. The
    `dof_dset.size` parameter of the `FunctionSpace` is used to exclude nodes in
    the halo region of the domain.

    Args:
        boundary_id: ID, or sequence of IDs, of the domain boundary

    Returns:
        List of integers corresponding to nodes on the boundary identified by
        `boundary_id`
    """
    self.check_boundary_id(boundary_id)
    bc = fd.DirichletBC(self.function_space, 0, boundary_id)
    return [n for n in bc.nodes if n < self.function_space.dof_dset.size]
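A rank's local numbering contains both the nodes it owns and halo copies of nodes owned by neighbouring ranks. get_boundary_nodes keeps only indices below `dof_dset.size`, which relies on owned nodes being numbered before halo nodes. The helper below is hypothetical and shows only that filter. Applying it means that a per-rank sum over boundary nodes, followed by a reduction across ranks, counts each shared node once.

```python
def owned_boundary_nodes(bc_nodes, n_owned):
    """Keep only boundary nodes owned by this process.

    Assumes, as get_boundary_nodes does, that local indices 0 .. n_owned - 1
    are owned and indices >= n_owned belong to the halo.
    """
    return [n for n in bc_nodes if n < n_owned]


# Hypothetical local numbering: 10 owned nodes (0-9), halo nodes from 10 upward
print(owned_boundary_nodes([0, 3, 7, 9, 12, 15], n_owned=10))  # [0, 3, 7, 9]
```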

BaseDiagnostics(quad_degree, **funcs)

A base class containing useful operations for diagnostics

For each Firedrake function passed as a keyword argument in the funcs parameter, store that function as an instance attribute named after its keyword, e.g.:

diag = BaseDiagnostics(quad_degree, z=z)

sets diag.z to the Firedrake function z. If z is defined on a mixed function space, its subfunctions are instead accessible by an index suffix, e.g. the same call

diag = BaseDiagnostics(quad_degree, z=z)

sets the subfunctions of z to diag.z_0, diag.z_1, etc. A FunctionContext is created for each registered function or subfunction, and these contexts are stored in the diag._function_contexts dict, keyed by function.

This class is intended to be subclassed by domain-specific diagnostic classes.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| quad_degree | int | Quadrature degree to use when approximating integrands managed by this instance | required |
| `**funcs` | Function \| None | Firedrake functions to associate with this instance | {} |

Initialise a BaseDiagnostics object.

Sets the quad_degree for measures used by this object and passes the remaining keyword arguments through to register_functions.

Source code in g-adopt/gadopt/diagnostics.py
def __init__(self, quad_degree: int, **funcs: fd.Function | None):
    """Initialise a BaseDiagnostics object.

    Sets the `quad_degree` for measures used by this object and passes the
    remaining keyword arguments through to `register_functions`.
    """
    self._function_contexts: dict[fd.Function | Operator, FunctionContext] = {}
    self._quad_degree = quad_degree
    self._mixed_functions: list[str] = []
    self.register_functions(**funcs)

register_functions(*, quad_degree=None, **funcs)

Register a function with this BaseDiagnostics object.

Creates a FunctionContext object for each function passed in as a keyword argument. Also creates an attribute on the instance, named after the keyword, through which the input function can be accessed, e.g.:

> diag.register_functions(F=F)
> type(diag.F)
<class 'firedrake.function.Function'>

If an input function is set to None, the attribute is still created but set to 0.0. If a function defined on a mixed function space is passed, each subfunction gets its own FunctionContext object, and its attribute name carries an additional suffix denoting the index of the subfunction, e.g.:

> diag.register_functions(F=F)
> type(diag.F)
AttributeError: "Requested 'F', which lives on a mixed space. Instead, access subfunctions via F_0, F_1, ..."
> type(diag.F_0)
<class 'firedrake.function.Function'>
> type(diag.F_1)
<class 'firedrake.function.Function'>

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| quad_degree | int \| None | The quadrature degree for the measures created for the registered functions. If None, the quad_degree passed at object instantiation time is used. | None |
| `**funcs` | Function \| None | Key-value pairs of Firedrake functions to associate with this instance | {} |

Source code in g-adopt/gadopt/diagnostics.py
def register_functions(
    self, *, quad_degree: int | None = None, **funcs: fd.Function | None
):
    """Register a function with this BaseDiagnostics object.

    Creates a `FunctionContext` object for each function passed in as a keyword
    argument. Also creates an attribute on the instance to access the input function
    named for the key of the keyword argument, e.g.:

    ```
    > diag.register_functions(F=F)
    > type(diag.F)
    <class 'firedrake.function.Function'>
    ```
    If an input function is set to `None`, the attribute will still be created
    but set to 0.0. If a mixed function is passed, each subfunction will have
    a `FunctionContext` object associated with it, and the attribute will be named
    with an additional number to denote the index of the subfunction, e.g.:

    ```
    > diag.register_functions(F=F)
    > type(diag.F)
    AttributeError: "Requested 'F', which lives on a mixed space. Instead, access subfunctions via F_0, F_1, ..."
    > type(diag.F_0)
    <class 'firedrake.function.Function'>
    > type(diag.F_1)
    <class 'firedrake.function.Function'>
    ```

    Args:
        quad_degree (optional): The quadrature degree for the measures to be used
        by this function. If `None`, the `quad_degree` passed at object
        instantiation time is used. Defaults to None.
        **funcs: key-value pairs of Firedrake functions to associate with this
        instance
    """
    if quad_degree is None:
        quad_degree = self._quad_degree
    for name, func in funcs.items():
        # Handle optional functions in diagnostics
        if func is None:
            setattr(self, name, 0.0)
            continue
        if len(func.subfunctions) == 1:
            if not hasattr(self, name):
                setattr(self, name, func)
                self._init_single_func(quad_degree, func)
        else:
            self._mixed_functions.append(name)
            for i, subfunc in enumerate(func.subfunctions):
                if not hasattr(self, f"{name}_{i}"):
                    setattr(self, f"{name}_{i}", subfunc)
                    self._init_single_func(quad_degree, subfunc)
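The naming rules of register_functions can be traced without Firedrake. The docstring describes an AttributeError when the base name of a mixed function is accessed, but the code that raises it is not part of the listing above. The `__getattr__` below is therefore an assumption about one way to implement that behaviour. `Registry` and `FakeFunction` are hypothetical stand-ins, and FunctionContext creation is omitted.

```python
class Registry:
    """Hypothetical stand-in that mimics the naming rules of register_functions."""

    def __init__(self):
        self._mixed_functions = []

    def register_functions(self, **funcs):
        for name, func in funcs.items():
            if func is None:
                setattr(self, name, 0.0)  # optional field: placeholder value
                continue
            if len(func.subfunctions) == 1:
                if not hasattr(self, name):
                    setattr(self, name, func)
            else:
                self._mixed_functions.append(name)
                for i, sub in enumerate(func.subfunctions):
                    if not hasattr(self, f"{name}_{i}"):
                        setattr(self, f"{name}_{i}", sub)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; the exact message
        # wording here is an assumption
        if name in self.__dict__.get("_mixed_functions", []):
            raise AttributeError(
                f"Requested '{name}', which lives on a mixed space. "
                f"Instead, access subfunctions via {name}_0, {name}_1, ..."
            )
        raise AttributeError(name)


class FakeFunction:
    """Hypothetical object exposing .subfunctions like a Firedrake Function."""

    def __init__(self, n_sub=1):
        if n_sub == 1:
            self.subfunctions = (self,)
        else:
            self.subfunctions = tuple(FakeFunction() for _ in range(n_sub))


t = FakeFunction()
diag = Registry()
diag.register_functions(T=t, z=FakeFunction(n_sub=2), eta=None)
print(diag.eta)                                    # 0.0
print(hasattr(diag, "z_0"), hasattr(diag, "z_1"))  # True True
print(hasattr(diag, "z"))                          # False: mixed name is rejected
```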

get_upward_component(f) cached

Get the upward (against gravity) component of a function.

Returns a UFL expression for the upward component of a function. Uses the G-ADOPT vertical_component function and caches the result such that the UFL expression only needs to be constructed once per run.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f | Function | Vector-valued Firedrake function | required |

Returns:

| Type | Description |
|---|---|
| Operator | UFL expression for the vertical component of f |

Source code in g-adopt/gadopt/diagnostics.py
@cache
def get_upward_component(self, f: fd.Function) -> Operator:
    """Get the upward (against gravity) component of a function.

    Returns a UFL expression for the upward component of a function. Uses the
    G-ADOPT `vertical_component` function and caches the result such that the
    UFL expression only needs to be constructed once per run.

    Args:
        f: Function

    Returns:
        UFL expression for the vertical component of `f`
    """
    self._check_present(f)
    self._check_dim_valid(f)  # Can't take upward component of a scalar function
    return vertical_component(f)
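Which direction counts as "upward" depends on the domain geometry, and the body of `vertical_component` is not shown here. As a plain-Python illustration only, assume a domain centred on the origin with gravity pointing toward it, as in a spherical or annular model. The upward component at a point x is then the projection of u onto the unit radial vector x/|x|. `upward_component` is a hypothetical helper that works on numbers, whereas get_upward_component returns a UFL expression.

```python
import math


def upward_component(u, x):
    """Project vector u onto the outward unit radial direction at position x.

    Hypothetical helper assuming radial gravity toward the origin; in a
    Cartesian box, "up" would instead be a fixed coordinate direction.
    """
    r = math.sqrt(sum(xi * xi for xi in x))
    return sum(ui * xi for ui, xi in zip(u, x)) / r


print(upward_component((3.0, 4.0), (0.0, 2.0)))  # 4.0: x lies on the y-axis
print(upward_component((3.0, 4.0), (5.0, 0.0)))  # 3.0: x lies on the x-axis
```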

min(func_or_op, boundary_id=None, dim=None)

Calculate the minimum value of a function. See _minmax docstring for more information.

Source code in g-adopt/gadopt/diagnostics.py
def min(
    self,
    func_or_op: fd.Function | Operator,
    boundary_id: Sequence[int | str] | int | str | None = None,
    dim: int | None = None,
) -> float:
    """
    Calculate the minimum value of a function. See `_minmax`
    docstring for more information.
    """
    return self._minmax(func_or_op, boundary_id, dim)[0]

max(func_or_op, boundary_id=None, dim=None)

Calculate the maximum value of a function. See _minmax docstring for more information.

Source code in g-adopt/gadopt/diagnostics.py
def max(
    self,
    func_or_op: fd.Function | Operator,
    boundary_id: Sequence[int | str] | int | str | None = None,
    dim: int | None = None,
) -> float:
    """
    Calculate the maximum value of a function. See `_minmax`
    docstring for more information.
    """
    return -self._minmax(func_or_op, boundary_id, dim)[1]
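max returns `-self._minmax(...)[1]`. This suggests that `_minmax` packs the minimum and the negated maximum together, so that a single MIN reduction yields both values. The ts_cache description of min/max reductions points the same way. `_minmax` itself is not shown, so this is an inference.

The sketch below simulates the idea in plain Python. Each inner list stands for one process's local values. `reduce_min` is a hypothetical stand-in for an elementwise MIN reduction across processes, for example an MPI allreduce with the MIN operation.

```python
def local_minmax(values):
    """Pack (min, -max) so that a single elementwise MIN reduction yields both."""
    return (min(values), -max(values))


def reduce_min(pairs):
    """Stand-in for an elementwise MIN reduction across processes."""
    return tuple(min(column) for column in zip(*pairs))


# Hypothetical per-process data for one field
rank_data = [[3.0, -1.5, 2.0], [0.5, 7.25], [4.0, -0.25]]
packed = reduce_min([local_minmax(v) for v in rank_data])
field_min, field_max = packed[0], -packed[1]
print(field_min, field_max)  # -1.5 7.25
```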

integral(f, boundary_id=None)

Calculate the integral of a function over the domain associated with it

Parameters:

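| Name | Type | Description | Default |
|---|---|---|---|
| f | Function | Firedrake function to integrate | required |
| boundary_id | Sequence[int \| str] \| int \| str \| None | Boundary ID. If not provided or set to None, integrates across the entire domain; if provided, integrates along the specified boundary only. | None |

Returns:

| Type | Description |
|---|---|
| float | Result of integration |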

Source code in g-adopt/gadopt/diagnostics.py
@ts_cache
def integral(
    self,
    f: fd.Function,
    boundary_id: Sequence[int | str] | int | str | None = None,
) -> float:
    """Calculate the integral of a function over the domain associated with it

    Args:
        f: Function.
        boundary_id (optional): Boundary ID. If not provided or set to `None`,
        will integrate across entire domain. If provided, will integrate along
        the specified boundary only. Defaults to None.

    Returns:
        Result of integration
    """
    self._check_present(f)
    measure = self._get_measure(f, boundary_id)
    return fd.assemble(f * measure)

l1norm(f, boundary_id=None)

Calculate the L1 norm of a function over the domain associated with it

Parameters:

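| Name | Type | Description | Default |
|---|---|---|---|
| f | Function | Firedrake function | required |
| boundary_id | Sequence[int \| str] \| int \| str \| None | Boundary ID. If not provided or set to None, integrates across the entire domain; if provided, integrates along the specified boundary only. | None |

Returns:

| Type | Description |
|---|---|
| float | L1 norm |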

Source code in g-adopt/gadopt/diagnostics.py
@ts_cache
def l1norm(
    self, f: fd.Function, boundary_id: Sequence[int | str] | int | str | None = None
) -> float:
    """Calculate the L1norm of a function over the domain associated with it

    Args:
        f: Function.
        boundary_id (optional): Boundary ID. If not provided or set to `None`,
        will integrate across entire domain. If provided, will integrate along
        the specified boundary only. Defaults to None.

    Returns:
        float: L1 norm
    """
    self._check_present(f)
    measure = self._get_measure(f, boundary_id)
    return fd.assemble(abs(f) * measure)

l2norm(f, boundary_id=None)

Calculate the L2 norm of a function over the domain associated with it

Parameters:

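| Name | Type | Description | Default |
|---|---|---|---|
| f | Function | Firedrake function | required |
| boundary_id | Sequence[int \| str] \| int \| str \| None | Boundary ID. If not provided or set to None, integrates across the entire domain; if provided, integrates along the specified boundary only. | None |

Returns:

| Type | Description |
|---|---|
| float | L2 norm |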

Source code in g-adopt/gadopt/diagnostics.py
@ts_cache
def l2norm(
    self, f: fd.Function, boundary_id: Sequence[int | str] | int | str | None = None
) -> float:
    """Calculate the L2norm of a function over the domain associated with it

    Args:
        f: Function.
        boundary_id (optional): Boundary ID. If not provided or set to `None`,
        will integrate across entire domain. If provided, will integrate along
        the specified boundary only. Defaults to None.

    Returns:
        float: L2 norm
    """
    self._check_present(f)
    measure = self._get_measure(f, boundary_id)
    return fd.sqrt(fd.assemble(fd.dot(f, f) * measure))
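For a field that is constant on each cell, the quantities computed by integral, l1norm and l2norm reduce to weighted sums. In that form the differences are easy to see: negative values cancel in the integral but not in the L1 norm. The sketch below is a hypothetical plain-Python analogue, not a Firedrake computation, and the cell values and widths are made up. For vector-valued fields, `f * f` would become the dot product, matching the `fd.dot(f, f)` integrand of l2norm.

```python
import math


def integral(values, widths):
    # Piecewise-constant field: the integral is the sum of value * cell width
    return sum(f * h for f, h in zip(values, widths))


def l1norm(values, widths):
    # Absolute values first, so positive and negative parts cannot cancel
    return sum(abs(f) * h for f, h in zip(values, widths))


def l2norm(values, widths):
    # Square root of the integral of f * f
    return math.sqrt(sum(f * f * h for f, h in zip(values, widths)))


values = [1.0, -1.0, 2.0]  # hypothetical cell values
widths = [0.5, 0.5, 1.0]   # hypothetical cell widths (1D "volumes")
print(integral(values, widths))  # 2.0: the first two cells cancel
print(l1norm(values, widths))    # 3.0
print(l2norm(values, widths))    # sqrt(5), about 2.236
```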

rms(f)

Calculate the RMS of a function over the domain associated with it

For the purposes of this function, RMS is defined as the L2 norm divided by the square root of the domain volume, i.e. L2norm/sqrt(volume)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f | Function | Firedrake function | required |

Returns:

| Type | Description |
|---|---|
| float | RMS |

Source code in g-adopt/gadopt/diagnostics.py
@ts_cache
def rms(self, f: fd.Function) -> float:
    """Calculate the RMS of a function over the domain associated with it

    For the purposes of this function, RMS is defined as L2norm/sqrt(volume)

    Args:
        f: Function.

    Returns:
        float: RMS
    """
    return self.l2norm(f) / fd.sqrt(self._function_contexts[f].volume)
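Because rms divides the L2 norm by the square root of the domain volume, the result is the square root of the volume-averaged `f * f`. A spatially constant field c therefore has RMS |c| regardless of domain size. The plain-Python sketch below is hypothetical and checks this on piecewise-constant cell data.

```python
import math


def rms(values, widths):
    """L2 norm divided by sqrt(volume), mirroring the formula used by rms()."""
    l2 = math.sqrt(sum(f * f * h for f, h in zip(values, widths)))
    volume = sum(widths)
    return l2 / math.sqrt(volume)


# Equal cell widths: the result is the familiar sqrt(mean(f**2))
print(rms([3.0, 4.0], [1.0, 1.0]))  # sqrt(12.5), about 3.5355
# A constant field has RMS equal to its magnitude, whatever the cell sizes
print(rms([2.0, 2.0, 2.0], [1.0, 2.0, 3.0]))  # 2.0
```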

GeodynamicalDiagnostics(z, T=None, /, bottom_id=None, top_id=None, *, quad_degree=4)

Bases: BaseDiagnostics

Typical simulation diagnostics used in geodynamical simulations.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| z | Function | Firedrake function on the mixed Stokes function space (velocity, pressure) | required |
| T | Function \| None | Firedrake function for temperature | None |
| bottom_id | Sequence[int \| str] \| int \| str \| None | Bottom boundary identifier | None |
| top_id | Sequence[int \| str] \| int \| str \| None | Top boundary identifier | None |
| quad_degree | int | Degree of polynomial quadrature approximation | 4 |
Note

All diagnostics are returned as floats.

Methods:

| Name | Description |
|---|---|
| u_rms | Root-mean-square velocity |
| u_rms_top | Root-mean-square velocity along the top boundary |
| Nu_top | Nusselt number at the top boundary |
| Nu_bottom | Nusselt number at the bottom boundary |
| T_avg | Average temperature in the domain |
| T_min | Minimum temperature in the domain |
| T_max | Maximum temperature in the domain |
| ux_max | Maximum velocity (first component, optionally over a given boundary) |

Source code in g-adopt/gadopt/diagnostics.py
def __init__(
    self,
    z: fd.Function,
    T: fd.Function | None = None,
    /,
    bottom_id: Sequence[int | str] | int | str | None = None,
    top_id: Sequence[int | str] | int | str | None = None,
    *,
    quad_degree: int = 4,
):
    u, p = z.subfunctions[:2]
    super().__init__(quad_degree, u=u, p=p, T=T)

    if bottom_id:
        self.bottom_id = bottom_id
        self.ds_b = self._function_contexts[self.u].ds(bottom_id)
    if top_id:
        self.top_id = top_id
        self.ds_t = self._function_contexts[self.u].ds(top_id)

GIADiagnostics(u, /, bottom_id=None, top_id=None, *, quad_degree=4)

Bases: BaseDiagnostics

Typical simulation diagnostics used in glacial isostatic adjustment simulations.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| u | Function | Firedrake function for displacement | required |
| bottom_id | Sequence[int \| str] \| int \| str \| None | Bottom boundary identifier | None |
| top_id | Sequence[int \| str] \| int \| str \| None | Top boundary identifier | None |
| quad_degree | int | Degree of polynomial quadrature approximation | 4 |
Note

All diagnostics are returned as floats.

Methods:

| Name | Description |
|---|---|
| u_rms | Root-mean-square displacement |
| u_rms_top | Root-mean-square displacement along the top boundary |
| ux_max | Maximum displacement (first component, optionally over a given boundary) |
| uv_min | Minimum vertical displacement, optionally over a given boundary |
| uv_max | Maximum vertical displacement, optionally over a given boundary |
| l2_norm_top | L2 norm of displacement on the top surface |
| l1_norm_top | L1 norm of displacement on the top surface |
| integrated_displacement | Integral of displacement on the top surface |

Source code in g-adopt/gadopt/diagnostics.py
def __init__(
    self,
    u: fd.Function,
    /,
    bottom_id: Sequence[int | str] | int | str | None = None,
    top_id: Sequence[int | str] | int | str | None = None,
    *,
    quad_degree: int = 4,
):
    super().__init__(quad_degree, u=u)

    if bottom_id:
        self.ds_b = self._function_contexts[self.u].ds(bottom_id)
    if top_id:
        self.ds_t = self._function_contexts[self.u].ds(top_id)
        self.top_id = top_id

uv_min(boundary_id=None)

Minimum value of vertical component of velocity/displacement

Source code in g-adopt/gadopt/diagnostics.py
def uv_min(self, boundary_id: int | None = None) -> float:
    "Minimum value of vertical component of velocity/displacement"
    return self.min(self.get_upward_component(self.u), boundary_id)

uv_max(boundary_id=None)

Maximum value of vertical component of velocity/displacement

Source code in g-adopt/gadopt/diagnostics.py
def uv_max(self, boundary_id: int | None = None) -> float:
    "Maximum value of vertical component of velocity/displacement"
    return self.max(self.get_upward_component(self.u), boundary_id)

extract_functions(func_or_op) cached

Extract all Firedrake functions associated with a UFL expression.

This function recursively searches through any UFL expression for Firedrake Function objects.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| func_or_op | Expr | The UFL expression to search through | required |

Raises:

| Type | Description |
|---|---|
| TypeError | An object that was neither a UFL Operator nor a UFL Terminal was encountered |

Returns:

| Type | Description |
|---|---|
| set[Function] | The set of found Firedrake Functions |

Source code in g-adopt/gadopt/diagnostics.py
@cache
def extract_functions(func_or_op: Expr) -> set[fd.Function]:
    """Extract all Firedrake functions associated with a UFL expression.

    This function recursively searches through any UFL expression for Firedrake
    `Function` objects.

    Args:
        func_or_op: The UFL expression to search through

    Raises:
        TypeError: An object that was neither a UFL Operator nor a UFL Terminal
        was encountered

    Returns:
        The set of found Firedrake Functions
    """
    if isinstance(func_or_op, fd.Function):
        return {func_or_op}
    elif isinstance(func_or_op, Operator):
        funcs = set()
        for f in func_or_op.ufl_operands:
            funcs |= extract_functions(f)
        return funcs
    elif isinstance(func_or_op, Terminal):
        # Some other UFL object
        return set()
    else:
        raise TypeError("Invalid type")
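The recursion above can be exercised without Firedrake using hypothetical stand-in classes. Branch order matters here. A Firedrake Function is itself a UFL Terminal (through `ufl.Coefficient`), so the Function check must come before the Terminal check. The stand-ins below mirror this by making `Function` a subclass of `Terminal`.

```python
class Terminal:
    """Hypothetical leaf node (stands in for a UFL Terminal such as a constant)."""


class Function(Terminal):
    """Hypothetical stand-in for a Firedrake Function, itself a terminal."""

    def __init__(self, name):
        self.name = name


class Operator:
    """Hypothetical stand-in for a UFL Operator with child operands."""

    def __init__(self, *operands):
        self.ufl_operands = operands


def extract_functions(expr):
    # Same branch order as the source: Function is checked before Terminal
    if isinstance(expr, Function):
        return {expr}
    if isinstance(expr, Operator):
        found = set()
        for operand in expr.ufl_operands:
            found |= extract_functions(operand)
        return found
    if isinstance(expr, Terminal):
        return set()
    raise TypeError("Invalid type")


u, T = Function("u"), Function("T")
expr = Operator(Operator(u, Terminal()), Operator(T, u))  # e.g. (u * c) + (T * u)
print(sorted(f.name for f in extract_functions(expr)))  # ['T', 'u']
```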

ts_cache(_func=None, *, input_funcs=None, make_key=partial(_make_key, typed=False))

Cache the results of a diagnostic function on a per-timestep basis

This function creates a decorator that caches the results of any diagnostic method of a BaseDiagnostics object (or any subclass thereof) for as long as the underlying Firedrake functions remain unmodified. Modification of Firedrake functions is tracked through the dat_version attribute of each function's dat object, which is based on the 'state' of the underlying PETSc object (see e.g. https://petsc.org/release/manualpages/Sys/PetscObjectStateGet/). PyOP2 also maintains a similar counter for non-PETSc objects.

The decorator serves two purposes. First, multiple calls to the same diagnostic function within a timestep can reuse already computed quantities (e.g. Nusselt numbers on top/bottom boundaries in energy conservation calculations). Second, underlying diagnostic algorithms can calculate multiple diagnostics at once when it is efficient to do so. For example, in large parallel applications small reductions are dominated by network communication time, so calculating both the minimum and maximum value of the same field simultaneously costs almost nothing extra.
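The version-based caching scheme described above can be sketched in plain Python. In this simplified, hypothetical version, `Field` stands in for a Firedrake function and its `version` attribute plays the role of `dat.dat_version`. The sketch tracks a single argument. ts_cache does more: it builds a key from all arguments with make_key, can track several functions per key, and recomputes on every call when no function is tracked.

```python
from functools import wraps


class Field:
    """Hypothetical field carrying a modification counter, like dat.dat_version."""

    def __init__(self, values):
        self.values = list(values)
        self.version = 0

    def assign(self, values):
        self.values = list(values)
        self.version += 1  # every modification bumps the counter


def version_cache(func):
    """Cache func(field) until field.version changes (simplified ts_cache)."""
    results, seen_version = {}, {}

    @wraps(func)
    def wrapper(field):
        # Fields hash by identity, so each field gets its own cache entry
        if seen_version.get(field) != field.version:
            results[field] = func(field)
            seen_version[field] = field.version
        return results[field]

    return wrapper


calls = []


@version_cache
def field_max(field):
    calls.append(field.version)  # record each real evaluation
    return max(field.values)


T = Field([1.0, 3.0, 2.0])
print(field_max(T), field_max(T), len(calls))  # 3.0 3.0 1
T.assign([5.0, 0.0])
print(field_max(T), len(calls))  # 5.0 2: the version bump forced recomputation
```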

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| _func | optional | Used to determine whether the decorator is being called with or without parentheses. | None |
| input_funcs | str \| Sequence[str] \| None | Name, or sequence of names, of Firedrake function attributes that the cached results depend on. The decorator automatically detects Firedrake functions in its arguments; this option allows custom diagnostics that do not take functions as arguments to correctly track dependent functions. By default, only automatically detected functions are tracked. | None |
| make_key | optional | A function that turns `*args` and `**kwargs` into a valid dictionary key. Defaults to the same method used by functools.cache with typed=False. | partial(_make_key, typed=False) |

Raises:

| Type | Description |
|---|---|
| TypeError | The decorator has been used on an object that is not a G-ADOPT BaseDiagnostics object, or an attribute specified in input_funcs is not a Firedrake function |
| AttributeError | The BaseDiagnostics object does not have an attribute named in the input_funcs argument. |

Source code in g-adopt/gadopt/diagnostics.py
def ts_cache(
    _func=None,
    *,
    input_funcs: str | Sequence[str] | None = None,
    make_key=partial(_make_key, typed=False),
):
    """Cache the results of a diagnostic function on a per-timestep basis

    This function creates a decorator that caches the results of any diagnostic
    function found in a BaseDiagnostic object (or any subclass thereof) for as long
    as the underlying Firedrake functions remain unmodified. The modification of
    Firedrake functions is tracked by the `dat_version` attribute of the `dat` object
    which is based on the 'state' of the underlying PETSc object (see e.g.
    https://petsc.org/release/manualpages/Sys/PetscObjectStateGet/). Pyop2 also
    maintains a similar counter for non-PETSc objects.

    The purpose of this decorator is to allow multiple calls to the same diagnostic
    function within a timestep to reuse already computed quantities (e.g. Nusselt
    numbers on top/bottom boundaries in energy conservation calculations) or for
    underlying diagnostic algorithms to calculate multiple diagnostics at once when
    it is efficient to do so (e.g. min/max field values: in large parallel
    applications small reductions are dominated by network communication time, so
    it costs almost nothing extra to calculate both the minimum and maximum value
    of the same field simultaneously).

    Args:
        _func (optional):
            Used to determine if the decorator is being called with or without
            parentheses.
        input_funcs (optional):
            A string or Sequence of strings naming Firedrake functions that the
            cached results depend on. The decorator will automatically detect
            Firedrake functions in its arguments; this option allows custom
            diagnostics that do not take functions as arguments to correctly track
            dependent functions. Default behaviour is to track automatically
            detected functions only.
        make_key (optional):
            A function to turn *args and **kwargs into a valid dictionary key.
            Defaults to the same method used by functools.cache with typed=False.

    Raises:
        TypeError:
            The decorator has been used on an object that is not a G-ADOPT
            BaseDiagnostic object
            An attribute specified in `input_funcs` is not a Firedrake function

        AttributeError:
            The BaseDiagnostic object does not have an attribute named in the
            `input_funcs` argument.
    """

    def ts_cache_decorator(diag_func):
        cache = {}
        funcs = defaultdict(set)
        object_state = defaultdict(lambda: defaultdict(lambda: int(-1)))
        check_funcs = set()
        if input_funcs is not None:
            check_funcs |= set(
                (input_funcs,) if isinstance(input_funcs, str) else input_funcs
            )

        def wrapper(*args, **kwargs):
            key = make_key(args, kwargs)
            if key not in cache:
                # Do all sanity checking on the first call to the decorator
                if len(args) == 0 or not isinstance(args[0], BaseDiagnostics):
                    raise TypeError(
                        "This decorator can only be used on G-ADOPT Diagnostics functions"
                    )
                # Find all Firedrake functions in the arguments to this decorator
                for arg in args[1:]:
                    if isinstance(arg, Expr):
                        funcs[key] |= extract_functions(arg)
                for arg in kwargs.values():
                    if isinstance(arg, Expr):
                        funcs[key] |= extract_functions(arg)
                # Add any functions that were specified in the arguments to the
                # decorator factory
                for f in check_funcs:
                    if hasattr(args[0], f):
                        func = getattr(args[0], f)
                    else:
                        raise AttributeError(
                            f"No function named {f} found registered to this diagnostic object"
                        )
                    if not isinstance(func, fd.Function):
                        raise TypeError(
                            f"This diagnostic object has an attribute named {f} but it is not a Firedrake function"
                        )
                    funcs[key].add(func)
            if (
                any(object_state[key][f] != f.dat.dat_version for f in funcs[key])
                or not funcs[key]
            ):
                cache[key] = diag_func(*args, **kwargs)
                for f in funcs[key]:
                    object_state[key][f] = f.dat.dat_version
            return cache[key]

        return wrapper

    # See if we're being called as @ts_cache or @ts_cache().
    if _func is None:
        # We're called with parens.
        return ts_cache_decorator
    # We're called as @ts_cache without parens.
    return ts_cache_decorator(_func)
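ts_cache can be applied both as `@ts_cache` and as `@ts_cache(input_funcs=...)`. With parentheses, `_func` is None and the factory returns the decorator. When applied bare, Python passes the decorated function as `_func`. Making the options keyword-only (after `*`) keeps that first positional slot unambiguous.

The hypothetical `counted` decorator below shows the same idiom in isolation. Unlike the wrapper in the listing above, it applies `functools.wraps`, which preserves the wrapped function's `__name__` and `__doc__` for runtime introspection.

```python
from functools import wraps


def counted(_func=None, *, label="call"):
    """Decorator usable both as @counted and as @counted(label=...)."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            wrapper.log.append(label)  # record each call under the chosen label
            return func(*args, **kwargs)

        wrapper.log = []
        return wrapper

    if _func is None:
        # Called with parentheses: return the decorator to be applied next
        return decorator
    # Called bare: _func is the decorated function itself
    return decorator(_func)


@counted
def double(x):
    return 2 * x


@counted(label="sq")
def square(x):
    return x * x


print(double(4), square(3), double.log, square.log)  # 8 9 ['call'] ['sq']
```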