"""
Osses2021 Auditory Model w. Peripheral Filtering
================================================
Author:
Stefano Giacomelli - Ph.D. candidate @ DISIM dpt. - University of L'Aquila
License:
GNU General Public License v3.0 or later (GPLv3+)
This module implements the Osses et al. (2021) auditory model with realistic
peripheral filtering for fluctuation strength prediction. The model extends
the Dau1997 framework by incorporating headphone and middle ear transfer
functions, providing more accurate peripheral modeling for psychoacoustic
predictions.
The implementation is ported from the MATLAB Auditory Modeling Toolbox (AMT)
and extended with PyTorch for gradient-based optimization and GPU acceleration.
References
----------
.. [1] A. Osses Vecchi, L. E. García, and A. Kohlrausch, "Modelling the
sensation of fluctuation strength," in *Proc. Forum Acusticum 2020*,
Lyon, France, 2020, pp. 367-371.
.. [2] D. Pralong and S. Carlile, "The role of individualized headphone
calibration for the generation of high fidelity virtual auditory space,"
*J. Acoust. Soc. Am.*, vol. 100, no. 6, pp. 3785-3793, Dec. 1996.
.. [3] M. L. Jepsen, S. D. Ewert, and T. Dau, "A computational model of
human auditory signal processing and perception," *J. Acoust. Soc. Am.*,
vol. 124, no. 1, pp. 422-438, Jul. 2008.
.. [4] J. Breebaart, S. van de Par, and A. Kohlrausch, "Binaural processing
model based on contralateral inhibition. I. Model structure,"
*J. Acoust. Soc. Am.*, vol. 110, no. 2, pp. 1074-1088, Aug. 2001.
.. [5] T. Dau, D. Püschel, and A. Kohlrausch, "A quantitative model of the
'effective' signal processing in the auditory system. I. Model structure,"
*J. Acoust. Soc. Am.*, vol. 99, no. 6, pp. 3615-3622, Jun. 1996.
.. [6] P. Majdak, C. Hollomey, and R. Baumgartner, "AMT 1.x: A toolbox for
reproducible research in auditory modeling," *Acta Acust.*, vol. 6,
p. 19, 2022.
"""
from typing import Dict, Any, List
import torch
import torch.nn as nn
from torch_amt.common.ears import HeadphoneFilter, MiddleEarFilter
from torch_amt.common.filterbanks import GammatoneFilterbank
from torch_amt.common.ihc import IHCEnvelope
from torch_amt.common.adaptation import AdaptLoop
# from torch_amt.common.modulation import ModulationFilterbank
from torch_amt.common.modulation import FastModulationFilterbank
[docs]
class Osses2021(nn.Module):
r"""
Osses et al. (2021) auditory model with realistic peripheral filtering.
Implements a computational model of the auditory periphery designed for
fluctuation strength prediction and other psychoacoustic tasks. The model
extends the Dau1997 framework by incorporating headphone and middle ear
transfer functions, providing more accurate peripheral modeling suitable
for headphone-presented stimuli and perceptual predictions.
This implementation follows the MATLAB Auditory Modeling Toolbox (AMT)
osses2021 configuration and provides a differentiable, GPU-accelerated
version suitable for neural network training and optimization.
Algorithm Overview
------------------
The model implements a 6-stage auditory processing pipeline:
**Stage 0a: Headphone Filter**
Compensates for headphone and outer ear characteristics using Pralong & Carlile (1996):
.. math::
x_{hp}(t) = h_{hp}(t) * x(t)
where :math:`h_{hp}` models outer ear + headphone frequency response.
**Stage 0b: Middle Ear Filter**
Simulates middle ear transmission using Jepsen et al. (2008) model:
.. math::
x_{me}(t) = h_{me}(t) * x_{hp}(t)
Bandpass characteristic approximating middle ear transfer function.
**Stage 1: Gammatone Filterbank**
Decomposes signal into :math:`N` frequency channels with 1-ERB spacing:
.. math::
y_i(t) = g_i(t) * x_{me}(t), \\quad i=1,\\ldots,N
where :math:`g_i(t) = t^{n-1} e^{-2\\pi b t} \\cos(2\\pi f_c t + \\phi)`,
typically :math:`n=4` (4th-order).
**Stage 2: Inner Hair Cell (IHC) Envelope**
Extracts envelope via Breebaart et al. (2001) method:
.. math::
e_i(t) = |y_i(t)|^2 * h_{lp}(t)
where :math:`h_{lp}` is a lowpass filter extracting the envelope.
**Stage 3: Adaptation**
Five parallel adaptation loops with time constants :math:`\\tau_j`:
.. math::
v_i(t) = \\sum_{j=1}^5 a_j v_{i,j}(t)
where each :math:`v_{i,j}` follows:
.. math::
\\frac{dv_{i,j}}{dt} = \\frac{e_i(t) - v_{i,j}(t)}{\\tau_j}
Osses2021 preset uses :math:`\\text{limit}=5.0` (vs 10.0 in Dau1997).
**Stage 4: Modulation Filterbank**
Extracts modulation content using jepsen2008 preset:
.. math::
m_{i,k}(t) = h_{mod,k}(t) * v_i(t)
Configuration: 150 Hz lowpass, attenuation factor :math:`1/\\sqrt{2}`.
Output: List of :math:`N` tensors, tensor :math:`i` has shape :math:`(B, M_i, T)`
where :math:`M_i` varies per frequency channel.
Parameters
----------
fs : float
Sampling rate in Hz. Must match the audio sampling rate.
Common values: 44100, 48000, 32000 Hz.
flow : float, optional
Lower frequency bound for gammatone filterbank in Hz. Default: 80 Hz.
Determines the lowest auditory channel center frequency.
fhigh : float, optional
Upper frequency bound for gammatone filterbank in Hz. Default: 8000 Hz.
Determines the highest auditory channel center frequency.
phase_type : str, optional
Filter phase characteristic for peripheral filters. Default: 'minimum'.
Options:
- ``'minimum'``: Minimum-phase FIR (causal, introduces group delay)
- ``'zero'``: Zero-phase via filtfilt (non-causal, no phase distortion)
Use 'minimum' for real-time/causal processing, 'zero' for offline analysis.
learnable : bool, optional
If True, all model stages become trainable with gradient-based optimization.
Default: False (fixed parameters).
When True, enables end-to-end model training for task-specific optimization.
return_stages : bool, optional
If True, returns intermediate processing stages along with final output.
Default: False (only final modulation output).
Useful for visualization, analysis, and multi-stage training.
dtype : torch.dtype, optional
Data type for computations and parameters. Default: torch.float32.
Use torch.float64 for higher precision if needed (increases memory/computation).
**headphone_kwargs : dict, optional
Additional keyword arguments passed to :class:`HeadphoneFilter`.
Common options:
- ``filter_type`` (str): Filter characteristic. Default: 'pralong1996'.
- Other parameters accepted by HeadphoneFilter.
**middleear_kwargs : dict, optional
Additional keyword arguments passed to :class:`MiddleEarFilter`.
Common options:
- ``filter_type`` (str): Filter model. Default: 'jepsen2008'.
- ``normalize`` (bool): Normalize filter response. Default: True.
- Other parameters accepted by MiddleEarFilter.
**filterbank_kwargs : dict, optional
Additional keyword arguments passed to :class:`GammatoneFilterbank`.
Common options:
- ``n`` (int): Filter order. Default: 4 (4th-order gammatone).
- ``bandwidth_factor`` (float): Bandwidth scaling. Default: 1.0 (1-ERB).
- Other parameters accepted by GammatoneFilterbank.
**ihc_kwargs : dict, optional
Additional keyword arguments passed to :class:`IHCEnvelope`.
Common options:
- ``method`` (str): Extraction method. Default: 'breebaart2001'.
- ``cutoff`` (float): Lowpass cutoff frequency in Hz.
- Other parameters accepted by IHCEnvelope.
**adaptation_kwargs : dict, optional
Additional keyword arguments passed to :class:`AdaptLoop`.
Common options:
- ``preset`` (str): Configuration preset. Default: 'osses2021'.
- ``limit`` (float): Adaptation limit. Default: 5.0.
- ``num_loops`` (int): Number of adaptation loops. Default: 5.
- Other parameters accepted by AdaptLoop.
**modulation_kwargs : dict, optional
Additional keyword arguments passed to :class:`ModulationFilterbank`.
Common options:
- ``preset`` (str): Configuration preset. Default: 'jepsen2008'.
- ``lowpass_cutoff`` (float): Lowpass frequency in Hz. Default: 150.
- ``att_factor`` (float): Attenuation factor. Default: 1/√2.
- Other parameters accepted by ModulationFilterbank.
- Other parameters accepted by ModulationFilterbank.
Attributes
----------
fs : float
Sampling rate in Hz.
flow : float
Lower frequency bound for filterbank.
fhigh : float
Upper frequency bound for filterbank.
phase_type : str
Filter phase characteristic ('minimum' or 'zero').
learnable : bool
Whether model parameters are trainable.
return_stages : bool
Whether to return intermediate processing stages.
dtype : torch.dtype
Data type for computations.
headphone : HeadphoneFilter
Stage 0a: Headphone and outer ear filter.
middleear : MiddleEarFilter
Stage 0b: Middle ear transmission filter.
filterbank : GammatoneFilterbank
Stage 1: Gammatone auditory filterbank.
ihc : IHCEnvelope
Stage 2: Inner hair cell envelope extraction.
adaptation : AdaptLoop
Stage 3: Adaptation loops module.
modulation : ModulationFilterbank
Stage 4: Modulation filterbank.
fc : torch.Tensor
Center frequencies of auditory channels, shape (num_channels,) in Hz.
num_channels : int
Number of auditory frequency channels (depends on flow, fhigh, ERB spacing).
Input Shape
-----------
x : torch.Tensor
Audio signal with shape:
- :math:`(B, T)` - Batch of signals
- :math:`(C, T)` - Multi-channel input
- :math:`(T,)` - Single signal
where:
- :math:`B` = batch size
- :math:`C` = channels
- :math:`T` = time samples
Output Shape
------------
When ``return_stages=False`` (default):
List[torch.Tensor]
List of :math:`N` tensors (one per frequency channel), where tensor :math:`i`
has shape :math:`(B, M_i, T)`:
- :math:`N` = num_channels (number of auditory frequency channels)
- :math:`M_i` = number of modulation channels for frequency channel :math:`i`
- :math:`T` = time samples
Note: :math:`M_i` varies across frequency channels due to jepsen2008 preset.
When ``return_stages=True``:
Tuple[List[torch.Tensor], Dict[str, torch.Tensor]]
- First element: modulation output (List as above)
- Second element: dict with keys:
- ``'headphone'``: After headphone filter, shape :math:`(B, T)`
- ``'middleear'``: After middle ear filter, shape :math:`(B, T)`
- ``'filterbank'``: After gammatone, shape :math:`(B, N, T)`
- ``'ihc'``: After IHC envelope, shape :math:`(B, N, T)`
- ``'adaptation'``: After adaptation, shape :math:`(B, N, T)`
Examples
--------
**Basic usage:**
>>> import torch
>>> from torch_amt.models import Osses2021
>>>
>>> # Create model
>>> model = Osses2021(fs=44100)
>>>
>>> # Generate 1 second tone
>>> audio = torch.randn(1, 44100) * 0.01
>>>
>>> # Process
>>> output = model(audio)
>>> print(f"Number of frequency channels: {len(output)}")
Number of frequency channels: 31
>>> print(f"First channel shape (B, M, T): {output[0].shape}")
First channel shape (B, M, T): torch.Size([1, 13, 44100])
>>> print(f"Last channel shape: {output[-1].shape}")
Last channel shape: torch.Size([1, 8, 44100])
**With intermediate stages:**
>>> model_debug = Osses2021(fs=44100, return_stages=True)
>>> output, stages = model_debug(audio)
>>>
>>> print(f"Available stages: {list(stages.keys())}")
Available stages: ['headphone', 'middleear', 'filterbank', 'ihc', 'adaptation']
>>> print(f"Filterbank output: {stages['filterbank'].shape}")
Filterbank output: torch.Size([1, 31, 44100])
>>> print(f"After adaptation: {stages['adaptation'].shape}")
After adaptation: torch.Size([1, 31, 44100])
**Batch processing:**
>>> # Process multiple signals
>>> batch_audio = torch.randn(4, 44100) * 0.01
>>> output_batch = model(batch_audio)
>>> print(f"Batch output: {output_batch[0].shape}")
Batch output: torch.Size([4, 13, 44100])
**Zero-phase filters for offline analysis:**
>>> model_zero = Osses2021(fs=44100, phase_type='zero')
>>> output_zero = model_zero(audio)
>>> # No phase distortion, suitable for offline processing
**Custom frequency range:**
>>> # Narrow frequency range
>>> model_narrow = Osses2021(fs=44100, flow=500, fhigh=4000)
>>> print(f"Channels: {model_narrow.num_channels}")
Channels: 15
>>> print(f"Center frequencies: {model_narrow.fc}")
Center frequencies: tensor([ 500., 631., ..., 3175., 4000.])
**Custom submodule parameters:**
>>> # Custom adaptation limit
>>> model_custom = Osses2021(
... fs=44100,
... adaptation_kwargs={'limit': 10.0, 'num_loops': 7}
... )
>>>
>>> # Custom modulation filterbank
>>> model_mod = Osses2021(
... fs=44100,
... modulation_kwargs={'lowpass_cutoff': 200.0}
... )
**Learnable model for optimization:**
>>> model_learnable = Osses2021(fs=44100, learnable=True)
>>> n_params = sum(p.numel() for p in model_learnable.parameters())
>>> print(f"Trainable parameters: {n_params}")
Trainable parameters: 2648
>>>
>>> # Example training loop
>>> optimizer = torch.optim.Adam(model_learnable.parameters(), lr=1e-3)
>>> # ... training code ...
**Accessing center frequencies:**
>>> model = Osses2021(fs=44100)
>>> print(f"Center frequencies (Hz): {model.fc[:5]}")
Center frequencies (Hz): tensor([ 80.00, 100.79, 126.96, 159.91, 201.42])
>>> print(f"Frequency range: {model.fc[0]:.1f} - {model.fc[-1]:.1f} Hz")
Frequency range: 80.0 - 8000.0 Hz
Notes
-----
**Model Configuration:**
- **Headphone**: Pralong & Carlile (1996) outer ear + headphone compensation
- **Middle Ear**: Jepsen et al. (2008) bandpass characteristic
- **Filterbank**: Gammatone 4th-order, 1-ERB spacing
- **IHC**: Breebaart et al. (2001) envelope extraction method
- **Adaptation**: 5 loops, limit=5.0 (osses2021 preset)
- **Modulation**: jepsen2008 preset (150 Hz lowpass, att_factor=1/√2)
**Customizing Submodule Parameters:**
All submodules can be customized through dedicated kwargs dictionaries:
- Use ``headphone_kwargs`` to pass parameters to :class:`HeadphoneFilter`
- Use ``middleear_kwargs`` to pass parameters to :class:`MiddleEarFilter`
- Use ``filterbank_kwargs`` to pass parameters to :class:`GammatoneFilterbank`
- Use ``ihc_kwargs`` to pass parameters to :class:`IHCEnvelope`
- Use ``adaptation_kwargs`` to pass parameters to :class:`AdaptLoop`
- Use ``modulation_kwargs`` to pass parameters to :class:`ModulationFilterbank`
The ``learnable``, ``dtype``, and ``phase_type`` parameters are always
centralized and applied to all submodules automatically. Custom parameters
override defaults while maintaining the Osses2021 model structure.
**Phase Type Selection:**
The ``phase_type`` parameter controls peripheral filter characteristics:
- **'minimum'** (default): Causal minimum-phase FIR filters
- Introduces frequency-dependent group delay
- Suitable for real-time or causal processing
- Matches physiological phase response
- **'zero'**: Zero-phase filtering via filtfilt
- No phase distortion (symmetric impulse response)
- Non-causal (requires future samples)
- Better for offline analysis and visualization
Choose based on application: real-time → 'minimum', offline → 'zero'.
**Output Format (List of Tensors):**
Unlike other models that return a single tensor, Osses2021 returns a
**List[torch.Tensor]** because each frequency channel has a different
number of modulation channels:
.. code-block:: python
output = model(audio) # List of N tensors
for i, channel_output in enumerate(output):
print(f"Channel {i}: {channel_output.shape}")
# Channel 0: torch.Size([B, 13, T])
# Channel 1: torch.Size([B, 12, T])
# ...
This reflects the jepsen2008 modulation filterbank configuration where
higher frequency channels have fewer modulation channels. To work with
a single tensor, you can concatenate or pad as needed for your application.
**Computational Complexity:**
Processing time scales as:
.. math::
T_{compute} \\propto T \\cdot (N_{filter} + N_{filt} \\cdot N_{mod})
where :math:`T` = signal length, :math:`N_{filter}` = peripheral filter taps,
:math:`N_{filt}` = number of frequency channels (~31),
:math:`N_{mod}` = modulation channels per frequency (~8-13).
For 1 second @ 44.1 kHz: ~0.1-0.5 seconds on CPU, ~0.01-0.05 seconds on GPU.
**Memory Requirements:**
Peak memory with intermediate stages:
.. math::
Memory \\approx B \\cdot T \\cdot (N + \\sum_i M_i) \\cdot 4\\,\\text{bytes}
For batch=8, 1 second @ 44.1 kHz: ~40-80 MB.
**Applications:**
The model is particularly suited for:
- Fluctuation strength prediction (original application)
- Roughness and amplitude modulation detection
- Headphone-presented stimuli modeling
- Psychoacoustic feature extraction
- Perceptual quality assessment
- Machine learning auditory features
See Also
--------
HeadphoneFilter : Stage 0a - Headphone and outer ear filtering
MiddleEarFilter : Stage 0b - Middle ear transmission
GammatoneFilterbank : Stage 1 - Auditory filterbank
IHCEnvelope : Stage 2 - Inner hair cell envelope
AdaptLoop : Stage 3 - Adaptation loops
ModulationFilterbank : Stage 4 - Modulation analysis
References
----------
.. [1] A. Osses Vecchi, L. E. García, and A. Kohlrausch, "Modelling the
sensation of fluctuation strength," in *Proc. Forum Acusticum 2020*,
Lyon, France, 2020, pp. 367-371.
.. [2] D. Pralong and S. Carlile, "The role of individualized headphone
calibration for the generation of high fidelity virtual auditory space,"
*J. Acoust. Soc. Am.*, vol. 100, no. 6, pp. 3785-3793, Dec. 1996.
.. [3] M. L. Jepsen, S. D. Ewert, and T. Dau, "A computational model of
human auditory signal processing and perception," *J. Acoust. Soc. Am.*,
vol. 124, no. 1, pp. 422-438, Jul. 2008.
.. [4] J. Breebaart, S. van de Par, and A. Kohlrausch, "Binaural processing
model based on contralateral inhibition. I. Model structure,"
*J. Acoust. Soc. Am.*, vol. 110, no. 2, pp. 1074-1088, Aug. 2001.
.. [5] T. Dau, D. Püschel, and A. Kohlrausch, "A quantitative model of the
'effective' signal processing in the auditory system. I. Model structure,"
*J. Acoust. Soc. Am.*, vol. 99, no. 6, pp. 3615-3622, Jun. 1996.
.. [6] P. Majdak, C. Hollomey, and R. Baumgartner, "AMT 1.x: A toolbox for
reproducible research in auditory modeling," *Acta Acust.*, vol. 6,
p. 19, 2022.
"""
[docs]
def __init__(self,
fs: float,
flow: float = 80.0,
fhigh: float = 8000.0,
phase_type: str = 'minimum',
learnable: bool = False,
return_stages: bool = False,
dtype: torch.dtype = torch.float32,
headphone_kwargs: Dict[str, Any] = None,
middleear_kwargs: Dict[str, Any] = None,
filterbank_kwargs: Dict[str, Any] = None,
ihc_kwargs: Dict[str, Any] = None,
adaptation_kwargs: Dict[str, Any] = None,
modulation_kwargs: Dict[str, Any] = None):
super().__init__()
self.fs = fs
self.flow = flow
self.fhigh = fhigh
self.phase_type = phase_type
self.learnable = learnable
self.return_stages = return_stages
self.dtype = dtype
# Initialize kwargs dictionaries
headphone_kwargs = headphone_kwargs or {}
middleear_kwargs = middleear_kwargs or {}
filterbank_kwargs = filterbank_kwargs or {}
ihc_kwargs = ihc_kwargs or {}
adaptation_kwargs = adaptation_kwargs or {}
modulation_kwargs = modulation_kwargs or {}
# Stage 0a: Headphone filter (outer ear + headphone compensation)
self.headphone = HeadphoneFilter(fs=fs,
phase_type=phase_type,
learnable=learnable,
dtype=dtype,
**headphone_kwargs)
# Stage 0b: Middle ear filter (Jepsen 2008 variant)
middleear_defaults = {'filter_type': 'jepsen2008'}
middleear_params = {**middleear_defaults, **middleear_kwargs}
self.middleear = MiddleEarFilter(fs=fs,
phase_type=phase_type,
learnable=learnable,
dtype=dtype,
**middleear_params)
# Stage 1: Gammatone filterbank
filterbank_defaults = {'n': 4}
filterbank_params = {**filterbank_defaults, **filterbank_kwargs}
self.filterbank = GammatoneFilterbank(fc=(flow, fhigh),
fs=fs,
learnable=learnable,
dtype=dtype,
**filterbank_params)
self.fc = self.filterbank.fc
self.num_channels = self.filterbank.num_channels
# Stage 2: Inner hair cell envelope extraction (Breebaart 2001 method)
ihc_defaults = {'method': 'breebaart2001'}
ihc_params = {**ihc_defaults, **ihc_kwargs}
self.ihc = IHCEnvelope(fs=fs, learnable=learnable, dtype=dtype, **ihc_params)
# Stage 3: Adaptation loops (osses2021 preset: limit=5.0)
adaptation_defaults = {'preset': 'osses2021'}
adaptation_params = {**adaptation_defaults, **adaptation_kwargs}
self.adaptation = AdaptLoop(fs=fs, learnable=learnable, dtype=dtype, **adaptation_params)
# Stage 4: Modulation filterbank (jepsen2008 preset)
modulation_defaults = {'preset': 'jepsen2008'}
modulation_params = {**modulation_defaults, **modulation_kwargs}
# self.modulation = ModulationFilterbank(fs=fs, fc=self.fc, learnable=learnable, dtype=dtype, **modulation_params)
self.modulation = FastModulationFilterbank(fs=fs, fc=self.fc, learnable=learnable, dtype=dtype, **modulation_params)
[docs]
def forward(self, x: torch.Tensor) -> List[torch.Tensor] | tuple[List[torch.Tensor], Dict[str, Any]]:
"""
Process audio through the Osses2021 model.
Parameters
----------
x : torch.Tensor
Input audio signal. Shape: (B, T), (C, T), or (T,).
Returns
-------
List[torch.Tensor] or tuple
If return_stages=False:
List of tensors (one per frequency channel), each shape (B, M, T).
If return_stages=True:
Tuple of (output, stages) where stages is a dict with intermediate outputs.
"""
stages = {} if self.return_stages else None
# Normalize input shape to (B, T)
original_shape = x.shape
if x.ndim == 1:
x = x.unsqueeze(0) # [T] -> [1, T]
# Stage 0a: Headphone filter
# Input: [B, T], Output: [B, T]
x = self.headphone(x)
if self.return_stages:
stages['headphone'] = x.clone()
# Stage 0b: Middle ear filter
# Input: [B, T], Output: [B, T]
x = self.middleear(x)
if self.return_stages:
stages['middleear'] = x.clone()
# Stage 1: Gammatone filterbank
# Input: [B, T], Output: [B, F, T]
x = self.filterbank(x)
if self.return_stages:
stages['filterbank'] = x.clone()
# Stage 2: IHC envelope extraction
# Input: [B, F, T], Output: [B, F, T]
x = self.ihc(x)
if self.return_stages:
stages['ihc'] = x.clone()
# Stage 3: Adaptation
# Input: [B, F, T], Output: [B, F, T]
x = self.adaptation(x)
if self.return_stages:
stages['adaptation'] = x.clone()
# Stage 4: Modulation filterbank
# Input: [B, F, T], Output: List of [B, M_i, T]
output = self.modulation(x)
# If original input was 1D, squeeze batch dimension from all outputs
if len(original_shape) == 1:
output = [out.squeeze(0) if out.ndim == 3 else out for out in output]
if self.return_stages:
return output, stages
else:
return output
[docs]
def distribute_gradients(self):
"""
Distribute gradients to grouped filter coefficients in FastModulationFilterbank.
Call this method after ``loss.backward()`` to ensure all filter coefficients
in the modulation filterbank receive gradient updates. This is necessary when
using ``FastModulationFilterbank`` with ``learnable=True``, as filters are
grouped for efficiency and gradients need to be shared across group members.
Notes
-----
This method should be called in the training loop:
>>> model = Osses2021(fs=44100, learnable=True)
>>> output = model(input_signal)
>>> loss = criterion(output, target)
>>> loss.backward()
>>> model.distribute_gradients() # ← Important for FastModulationFilterbank!
>>> optimizer.step()
If the modulation filterbank doesn't have a ``distribute_gradients`` method
(e.g., using standard ModulationFilterbank), this is a no-op.
"""
if hasattr(self.modulation, 'distribute_gradients'):
self.modulation.distribute_gradients()