# Source code for ufl.algorithms.remove_component_tensors

"""Remove component tensors.

This module contains classes and functions to remove component tensors.
"""
# Copyright (C) 2025 Pablo Brubeck
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier:    LGPL-3.0-or-later

from collections import defaultdict

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.classes import ComponentTensor, Index, MultiIndex, Zero
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction
from ufl.index_combination_utils import unique_sorted_indices


class IndexReplacer(MultiFunction):
    """Replace Indices.

    Rewrites the entries of ``MultiIndex`` nodes and the free indices of
    ``Zero`` nodes according to a replacement map. All other node types fall
    back to ``MultiFunction.reuse_if_untouched``.
    """

    def __init__(self, fimap: dict):
        """Initialise.

        Args:
            fimap: map for index replacements. Keys are the indices to
                replace, values are what they are replaced with.
        """
        MultiFunction.__init__(self)
        self.fimap = fimap

    expr = MultiFunction.reuse_if_untouched

    def zero(self, o):
        """Handle Zero.

        Args:
            o: the ``Zero`` node to process.

        Returns:
            ``o`` itself when none of its free indices is a key of the
            replacement map; otherwise a new ``Zero`` with the same shape
            whose free indices are the replaced ones. Replacements that are
            not ``Index`` instances are dropped from the free indices.
        """
        indices = tuple(map(Index, o.ufl_free_indices))
        if not any(i in self.fimap for i in indices):
            # Reuse if untouched
            return o

        fi = []
        for i, d in zip(indices, o.ufl_index_dimensions):
            j = self.fimap.get(i, i)
            # Only Index replacements remain free; anything else (presumably
            # a fixed index) no longer contributes a free index.
            if isinstance(j, Index):
                fi.append((j.count(), d))
        fi = unique_sorted_indices(sorted(fi))
        if fi:
            free_indices, index_dimensions = zip(*fi)
        else:
            # Every free index was replaced by a non-Index: unpacking
            # zip(*fi) of an empty sequence would raise ValueError.
            free_indices, index_dimensions = (), ()
        return Zero(
            shape=o.ufl_shape,
            free_indices=free_indices,
            index_dimensions=index_dimensions,
        )

    def multi_index(self, o):
        """Handle MultiIndex.

        Args:
            o: the ``MultiIndex`` node to process.

        Returns:
            ``o`` itself when none of its entries is a key of the replacement
            map; otherwise a new ``MultiIndex`` with each entry replaced
            (entries without a replacement are kept as they are).
        """
        if not any(i in self.fimap for i in o):
            # Reuse if untouched
            return o

        indices = tuple(self.fimap.get(i, i) for i in o)
        return MultiIndex(indices)
class IndexRemover(MultiFunction):
    """Remove Indexed.

    Collapses an ``Indexed`` whose first operand is a ``ComponentTensor`` by
    substituting the indexing indices directly into the tensor's inner
    expression. Other node types use ``MultiFunction.reuse_if_untouched``.
    """

    expr = MultiFunction.reuse_if_untouched

    def __init__(self):
        """Initialise."""
        super().__init__()
        # One IndexReplacer per (bound indices, replacement indices) pair.
        self.rules = {}
        # caches for reuse in the dispatched transformers
        self.vcaches = defaultdict(dict)
        self.rcaches = defaultdict(dict)

    def indexed(self, o, o1, i1):
        """Simplify Indexed.

        Args:
            o: the ``Indexed`` node being visited.
            o1: the expression operand to use (presumably the already
                processed first operand of ``o`` - confirm against the
                traversal in ``map_expr_dag``).
            i1: the multi-index operand to use.

        Returns:
            The inner expression of ``o1`` with its bound indices replaced by
            ``i1`` when ``o1`` is a ``ComponentTensor``; otherwise ``o``
            itself if its first operand is ``o1``, or a reconstruction of
            ``o`` from ``o1`` and ``i1``.
        """
        if not isinstance(o1, ComponentTensor):
            if o.ufl_operands[0] is o1:
                # Reuse if untouched
                return o
            return o._ufl_expr_reconstruct_(o1, i1)

        # Simplify Indexed ComponentTensor
        inner, bound = o1.ufl_operands

        # Replace outer indices
        pair = (bound, i1)
        replacer = self.rules.get(pair)
        if replacer is None:
            # NOTE: Replace with `dict(zip(bound, i1, strict=True))` when
            # Python>=3.10
            assert len(bound) == len(i1)
            replacer = IndexReplacer(dict(zip(bound, i1)))
            self.rules[pair] = replacer

        cache_key = (IndexReplacer, bound, i1)
        return map_expr_dag(
            replacer,
            inner,
            vcache=self.vcaches[cache_key],
            rcache=self.rcaches[cache_key],
        )
def remove_component_tensors(o):
    """Remove component tensors.

    Args:
        o: the object to process; it is handed to ``map_integrand_dags``
            together with a fresh ``IndexRemover``.

    Returns:
        Whatever ``map_integrand_dags`` returns for ``o``.
    """
    return map_integrand_dags(IndexRemover(), o)