Source code for torch_amt.common.ihc

"""
Inner Hair Cell Models
======================

Author:
    Stefano Giacomelli - Ph.D. candidate @ DISIM dpt. - University of L'Aquila

License:
    GNU General Public License v3.0 or later (GPLv3+)

This module implements inner hair cell (IHC) transduction models that convert 
basilar membrane motion to neural signals. Two main approaches are provided:

1. **IHCEnvelope**: Classical envelope extraction via half-wave rectification 
   and low-pass filtering. Multiple preset configurations match published models 
   from the auditory literature (Dau1996, Breebaart2001, Lindemann1986, King2019).

2. **IHCPaulick2024**: Physiologically detailed IHC transduction for the CASP 
   model, including mechano-electrical transduction (MET) channel dynamics and 
   electrical circuit modeling.

The implementations follow the Auditory Modeling Toolbox (AMT) for MATLAB/Octave, 
ensuring compatibility with established computational auditory models.

References
----------
.. [1] P. Majdak, C. Hollomey, and R. Baumgartner, "AMT 1.x: A toolbox for 
       reproducible research in auditory modeling," *Acta Acustica*, vol. 6, 
       p. 19, 2022, doi: 10.1051/aacus/2022011.

.. [2] P. Søndergaard and P. Majdak, "The Auditory Modeling Toolbox," in 
       *The Technology of Binaural Listening*, J. Blauert, Ed. 
       Berlin-Heidelberg, Germany: Springer, 2013, pp. 33-56, 
       doi: 10.1007/978-3-642-37762-4_2.

.. [3] P. Majdak et al., "The Auditory Modeling Toolbox 1.x Full Packages," 
       SourceForge, 2022. [Online]. Available: 
       https://sourceforge.net/projects/amtoolbox/files/AMT%201.x/
"""

from typing import Optional

import torch
import torch.nn as nn
from scipy.signal import butter

# ------------------------------------------------- Utilities ------------------------------------------------

@torch.jit.script
def _precharge_circuit_jit(batch_size: int,
                           num_channels: int,
                           fs: float,
                           precharge_duration: float,
                           V_rest: torch.Tensor,
                           G_precharge: torch.Tensor,
                           EP: torch.Tensor,
                           Gkf: torch.Tensor,
                           Gks: torch.Tensor,
                           Ekf: torch.Tensor,
                           Eks: torch.Tensor,
                           Cm: torch.Tensor,
                           dtype: torch.dtype,
                           device: torch.device) -> torch.Tensor:
    """
    TorchScript pre-charge solver for the IHC membrane circuit.

    Starting from ``V_rest`` in every (batch, channel) cell, the membrane
    voltage is advanced with Forward Euler steps of size ``1 / fs`` while the
    MET conductance is held constant at ``G_precharge``. The number of steps
    is ``int(fs * precharge_duration)``; the voltage after the last step is
    returned.

    Parameters
    ----------
    batch_size : int
        Batch size B (first dimension of the returned tensor).

    num_channels : int
        Number of frequency channels F (second dimension of the returned
        tensor).

    fs : float
        Sampling rate in Hz; the integration step is ``1 / fs``.

    precharge_duration : float
        Simulated pre-charge time in seconds.

    V_rest : torch.Tensor
        Initial voltage in Volts. Must hold exactly one element because it is
        read with ``.item()`` (which also detaches it from autograd).

    G_precharge : torch.Tensor
        Constant MET conductance used during pre-charge, in Siemens.

    EP : torch.Tensor
        Endocochlear potential in Volts.

    Gkf : torch.Tensor
        Fast K+ conductance in Siemens.

    Gks : torch.Tensor
        Slow K+ conductance in Siemens.

    Ekf : torch.Tensor
        Fast K+ reversal potential in Volts.

    Eks : torch.Tensor
        Slow K+ reversal potential in Volts.

    Cm : torch.Tensor
        Membrane capacitance in Farads.

    dtype : torch.dtype
        Data type of the initial voltage tensor.

    device : torch.device
        Device on which the initial voltage tensor is allocated.

    Returns
    -------
    torch.Tensor
        Voltage after the pre-charge phase, shape (B, F).

    Notes
    -----
    The circuit tensors are combined with the (B, F) voltage through ordinary
    broadcasting; callers in this file pass 0-dim tensors.
    """
    dt = 1.0 / fs
    num_steps = int(fs * precharge_duration)

    # Every cell starts from the same resting potential.
    voltage = torch.full((batch_size, num_channels),
                         V_rest.item(),
                         dtype=dtype,
                         device=device)

    # Explicit Euler steps with the MET conductance frozen at G_precharge.
    step = 0
    while step < num_steps:
        i_met = G_precharge * (voltage - EP)
        i_fast = Gkf * (voltage - Ekf)
        i_slow = Gks * (voltage - Eks)
        i_total = i_met + i_fast + i_slow
        voltage = voltage - i_total * dt / Cm
        step += 1

    return voltage


@torch.jit.script
def _solve_circuit_ode_jit(G: torch.Tensor,
                           V_precharge: torch.Tensor,
                           EP: torch.Tensor,
                           Gkf: torch.Tensor,
                           Gks: torch.Tensor,
                           Ekf: torch.Tensor,
                           Eks: torch.Tensor,
                           Cm: torch.Tensor,
                           fs: float) -> torch.Tensor:
    """
    TorchScript Forward Euler integrator for the IHC membrane circuit.

    For every time index the net membrane current is evaluated from the
    time-varying MET conductance and the two fixed K+ branches, the voltage
    is advanced by one step of size ``1 / fs``, and the updated voltage is
    stored. The stored trace is returned with ``V_precharge`` subtracted.

    Parameters
    ----------
    G : torch.Tensor
        MET channel conductance in Siemens, shape (B, F, T). Must be 3-D:
        its shape is unpacked into exactly three sizes.

    V_precharge : torch.Tensor
        Initial voltage in Volts, shape (B, F). It is cloned, so the caller's
        tensor is not modified.

    EP : torch.Tensor
        Endocochlear potential in Volts.

    Gkf : torch.Tensor
        Fast K+ conductance in Siemens.

    Gks : torch.Tensor
        Slow K+ conductance in Siemens.

    Ekf : torch.Tensor
        Fast K+ reversal potential in Volts.

    Eks : torch.Tensor
        Slow K+ reversal potential in Volts.

    Cm : torch.Tensor
        Membrane capacitance in Farads.

    fs : float
        Sampling rate in Hz; the integration step is ``1 / fs``.

    Returns
    -------
    torch.Tensor
        Voltage relative to ``V_precharge``, shape (B, F, T), allocated with
        ``torch.zeros_like(G)`` (same dtype and device as ``G``).
    """
    n_batch, n_chan, num_steps = G.shape
    dt = 1.0 / fs

    trace = torch.zeros_like(G)
    voltage = V_precharge.clone()  # (B, F) running membrane voltage

    for k in range(num_steps):
        # Net current: MET branch plus fast and slow K+ branches.
        i_net = (-G[:, :, k] * (voltage - EP)
                 - Gkf * (voltage - Ekf)
                 - Gks * (voltage - Eks))

        # Euler step, then record the *updated* voltage for this sample.
        voltage = voltage + i_net * dt / Cm
        trace[:, :, k] = voltage

    # Express the result as a deviation from the pre-charge level.
    return trace - V_precharge.unsqueeze(-1)

# ------------------------------------------- Inner Hair Cell Envelope ---------------------------------------

class IHCEnvelope(nn.Module):
    r"""
    Inner hair cell envelope extraction.

    Models inner hair cell (IHC) transduction by half-wave rectifying the
    basilar membrane signal and low-pass filtering the result with a
    Butterworth IIR filter.

    Algorithm Overview
    ------------------
    1. **Half-wave rectification**:

       .. math::
           x_{\text{rect}}(t) = \max(x(t), 0)

    2. **Butterworth low-pass filtering**, applied ``iterations`` times:

       .. math::
           y(t) = \text{IIR}(x_{\text{rect}}(t), b, a)

       where :math:`b, a` come from ``scipy.signal.butter``.

    Parameters
    ----------
    fs : float
        Sampling rate in Hz.

    cutoff : float, optional
        Low-pass cutoff in Hz. If ``None``, the default of the selected
        preset is used. Default: ``None``.

    order : int, optional
        Order of the Butterworth filter. All presets are defined with a
        1st-order filter, which is the default; other values are passed to
        the filter design as given. Default: 1.

    method : {'dau1996', 'breebaart2001', 'king2019', 'lindemann'}, optional
        Preset selecting the default cutoff and the number of filter passes:

        * ``'dau1996'``: 1000 Hz, 1 pass.
        * ``'breebaart2001'``: 2000 Hz, 5 passes (effective ~770 Hz for the
          default 1st-order filter).
        * ``'king2019'``: 1500 Hz, 1 pass.
        * ``'lindemann'``: 800 Hz, 1 pass.

        Default: ``'dau1996'``.

    learnable : bool, optional
        If ``True``, ``b`` and ``a`` are ``nn.Parameter`` objects; otherwise
        they are registered as buffers. Default: ``False``.

    dtype : torch.dtype, optional
        Data type of the filter coefficients. Default: ``torch.float32``.

    Attributes
    ----------
    fs : float
        Sampling rate in Hz.

    method : str
        Preset name.

    cutoff : float
        Cutoff frequency in Hz after preset resolution.

    order : int
        Filter order used for the design.

    iterations : int
        Number of filter passes (5 for ``breebaart2001``, 1 otherwise).

    b, a : torch.Tensor or nn.Parameter
        Numerator / denominator coefficients, each of shape ``(order+1,)``.

    learnable : bool
        Whether the coefficients are trainable.

    dtype : torch.dtype
        Data type of the coefficients.

    Shape
    -----
    - Input: :math:`(B, F, T)` or :math:`(F, T)`.
    - Output: same shape as the input.

    Raises
    ------
    ValueError
        If ``method`` is not one of the presets listed above.

    Notes
    -----
    **Preset defaults** (with the default ``order=1``)

    +-----------------+-------------+------------+-------------------+
    | Method          | Cutoff (Hz) | Iterations | Effective Cutoff  |
    +=================+=============+============+===================+
    | dau1996         | 1000        | 1          | 1000 Hz           |
    +-----------------+-------------+------------+-------------------+
    | breebaart2001   | 2000        | 5          | ~770 Hz           |
    +-----------------+-------------+------------+-------------------+
    | king2019        | 1500        | 1          | 1500 Hz           |
    +-----------------+-------------+------------+-------------------+
    | lindemann       | 800         | 1          | 800 Hz            |
    +-----------------+-------------+------------+-------------------+

    Cascading five 1st-order low-pass stages at 2000 Hz lowers the overall
    -3 dB point to roughly :math:`2000\sqrt{2^{1/5}-1} \approx 770` Hz.

    The IIR filter is evaluated with a Python loop over time samples
    (vectorized across batch and channel), so run time grows linearly with
    the signal length.

    See Also
    --------
    IHCPaulick2024 : Physiologically detailed IHC transduction (CASP model)
    GammatoneFilterbank : Gammatone peripheral filtering (typical input)
    DRNLFilterbank : Dual-resonance non-linear filterbank (alternative input)
    AdaptLoop : Auditory nerve adaptation (downstream processing)

    Examples
    --------
    >>> import torch
    >>> from torch_amt.common.ihc import IHCEnvelope
    >>> ihc = IHCEnvelope(fs=44100, method='dau1996')
    >>> x = torch.randn(1, 31, 44100)
    >>> y = ihc(x)
    >>> print(y.shape)
    torch.Size([1, 31, 44100])

    >>> ihc_learn = IHCEnvelope(fs=44100, method='dau1996', learnable=True)
    >>> print(sum(p.numel() for p in ihc_learn.parameters()))
    4

    References
    ----------
    .. [1] T. Dau, D. Püschel, and A. Kohlrausch, "A quantitative model of the
           'effective' signal processing in the auditory system. I. Model
           structure," *J. Acoust. Soc. Am.*, vol. 99, no. 6, pp. 3615-3622,
           1996.

    .. [2] J. Breebaart, S. van de Par, and A. Kohlrausch, "Binaural
           processing model based on contralateral inhibition. I. Model
           structure," *J. Acoust. Soc. Am.*, vol. 110, no. 2, pp. 1074-1088,
           2001.

    .. [3] W. Lindemann, "Extension of a binaural cross-correlation model by
           contralateral inhibition. I. Simulation of lateralization for
           stationary signals," *J. Acoust. Soc. Am.*, vol. 80, no. 6,
           pp. 1608-1622, 1986.

    .. [4] A. J. King, J. W. H. Schnupp, and A. R. D. Thornton, "Localization
           of sounds in the median sagittal plane with and without spectral
           cues," *J. Acoust. Soc. Am.*, vol. 145, no. 3, pp. 1437-1447, 2019.
           NOTE(review): this citation could not be verified; the AMT
           ``king2019`` model is presumably King, Varnet & Lorenzi (2019) -
           TODO confirm against the AMT documentation.
    """

    # Preset name -> (default cutoff in Hz, number of filter passes).
    _PRESETS = {
        'dau1996': (1000.0, 1),
        'breebaart2001': (2000.0, 5),
        'king2019': (1500.0, 1),
        'lindemann': (800.0, 1),
    }

    def __init__(self,
                 fs: float,
                 cutoff: Optional[float] = None,
                 order: int = 1,
                 method: str = 'dau1996',
                 learnable: bool = False,
                 dtype: torch.dtype = torch.float32):
        super().__init__()

        if method not in self._PRESETS:
            raise ValueError(f"Unknown method: {method}")
        default_cutoff, iterations = self._PRESETS[method]

        self.fs = fs
        self.method = method
        self.dtype = dtype
        self.learnable = learnable

        self.cutoff = cutoff if cutoff is not None else default_cutoff
        # Honour the caller's filter order. It used to be overwritten with 1
        # in every preset branch, which silently discarded the argument.
        self.order = order
        self.iterations = iterations

        b_init, a_init = self._design_butterworth_lowpass(self.cutoff,
                                                          self.fs,
                                                          self.order)
        if learnable:
            self.b = nn.Parameter(b_init)
            self.a = nn.Parameter(a_init)
        else:
            self.register_buffer('b', b_init)
            self.register_buffer('a', a_init)

    def _design_butterworth_lowpass(self,
                                    cutoff: float,
                                    fs: float,
                                    order: int) -> tuple[torch.Tensor, torch.Tensor]:
        r"""
        Design digital Butterworth low-pass coefficients.

        Parameters
        ----------
        cutoff : float
            Cutoff frequency in Hz.

        fs : float
            Sampling rate in Hz.

        order : int
            Filter order.

        Returns
        -------
        tuple of torch.Tensor
            ``(b, a)`` numerator and denominator coefficients, each of shape
            ``(order+1,)`` and of dtype ``self.dtype``.

        Notes
        -----
        The cutoff is normalized to the Nyquist frequency,
        :math:`\omega_n = f_c / (f_s / 2)`, before being handed to
        ``scipy.signal.butter`` (``btype='low'``, ``analog=False``).
        """
        wn = cutoff / (fs / 2)
        b, a = butter(order, wn, btype='low', analog=False)
        return torch.tensor(b, dtype=self.dtype), torch.tensor(a, dtype=self.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Rectify the input and low-pass filter it ``iterations`` times.

        Parameters
        ----------
        x : torch.Tensor
            Input signal of shape :math:`(B, F, T)` or :math:`(F, T)`.

        Returns
        -------
        torch.Tensor
            Envelope signal with the same shape as ``x``.

        Notes
        -----
        The coefficients ``b`` and ``a`` are moved to ``x.device`` on each
        pass; their dtype is left unchanged.
        """
        # Half-wave rectification.
        x = torch.clamp(x, min=0.0)

        for _ in range(self.iterations):
            x = self._apply_iir_filter(x, self.b.to(x.device), self.a.to(x.device))
        return x

    def _apply_iir_filter(self,
                          x: torch.Tensor,
                          b: torch.Tensor,
                          a: torch.Tensor) -> torch.Tensor:
        """
        Apply an IIR filter along the last (time) dimension.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape :math:`(B, F, T)` or :math:`(F, T)`.

        b : torch.Tensor
            Numerator coefficients, shape ``(nb,)``.

        a : torch.Tensor
            Denominator coefficients, shape ``(na,)``.

        Returns
        -------
        torch.Tensor
            Filtered signal with the same shape as ``x``.

        Notes
        -----
        Both coefficient vectors are divided by ``a[0]`` first. A 2-D input
        gets a temporary leading batch dimension. On MPS devices the data is
        filtered on the CPU and moved back afterwards (workaround for crashes
        with in-place indexing reported in the original implementation).
        """
        # Normalize so that a[0] == 1.
        a0 = a[0]
        b = b / a0
        a = a / a0

        was_2d = x.ndim == 2
        if was_2d:
            x = x.unsqueeze(0)
        batch_size, num_channels, siglen = x.shape

        original_device = x.device
        on_mps = original_device.type == 'mps'
        if on_mps:
            x, b, a = x.cpu(), b.cpu(), a.cpu()

        # Treat every (batch, channel) pair as an independent signal.
        y = self._lfilter_vectorized(x.reshape(-1, siglen), b, a)

        if on_mps:
            y = y.to(original_device)

        y = y.reshape(batch_size, num_channels, siglen)
        return y.squeeze(0) if was_2d else y

    def _lfilter_vectorized(self,
                            x: torch.Tensor,
                            b: torch.Tensor,
                            a: torch.Tensor) -> torch.Tensor:
        r"""
        Filter ``N`` signals at once (Direct Form II Transposed).

        Parameters
        ----------
        x : torch.Tensor
            Input signals, shape ``(N, T)``.

        b : torch.Tensor
            Normalized numerator coefficients, shape ``(nb,)``.

        a : torch.Tensor
            Normalized denominator coefficients (``a[0] == 1`` is assumed),
            shape ``(na,)``.

        Returns
        -------
        torch.Tensor
            Filtered signals, shape ``(N, T)``.

        Notes
        -----
        The loop runs over time samples; each step is vectorized across the
        ``N`` signals. With state vector :math:`s` (zero-initialized):

        .. math::
            y[n] &= b_0 x[n] + s_0[n-1] \\
            s_i[n] &= b_{i+1} x[n] - a_{i+1} y[n] + s_{i+1}[n-1]

        where the last state has no :math:`s_{i+1}` term. Coefficients beyond
        the end of ``b`` or ``a`` are treated as zero. A new state tensor is
        built at every step instead of being updated in place.
        """
        n_b = len(b)
        n_a = len(a)
        n_state = max(n_b, n_a) - 1

        if n_state == 0:
            # Pure gain, no memory.
            return b[0] * x

        N, T = x.shape
        y = torch.zeros_like(x)
        state = torch.zeros(N, n_state, dtype=x.dtype, device=x.device)

        for t in range(T):
            x_t = x[:, t]
            y_t = b[0] * x_t + state[:, 0]
            y[:, t] = y_t

            next_state = []
            for i in range(n_state):
                b_i = b[i + 1] if i + 1 < n_b else 0.0
                a_i = a[i + 1] if i + 1 < n_a else 0.0
                s_i = b_i * x_t - a_i * y_t
                if i + 1 < n_state:
                    # Every state but the last pulls in its successor.
                    s_i = s_i + state[:, i + 1]
                next_state.append(s_i)
            state = torch.stack(next_state, dim=1)

        return y

    def _lfilter_single(self,
                        x: torch.Tensor,
                        b: torch.Tensor,
                        a: torch.Tensor) -> torch.Tensor:
        r"""
        Filter one signal using Direct Form II Transposed.

        Implements

        .. math::
            a[0] y[n] = b[0] x[n] + \cdots + b[nb] x[n-nb]
                        - a[1] y[n-1] - \cdots - a[na] y[n-na]

        Parameters
        ----------
        x : torch.Tensor
            Input signal, shape ``(T,)``.

        b : torch.Tensor
            Normalized numerator coefficients, shape ``(nb,)``.

        a : torch.Tensor
            Normalized denominator coefficients (``a[0] == 1`` is assumed),
            shape ``(na,)``.

        Returns
        -------
        torch.Tensor
            Filtered signal, shape ``(T,)``.

        Notes
        -----
        The state vector has length :math:`\max(nb, na) - 1` and is updated
        in place, sample by sample. No other method of this class calls this
        helper.
        """
        n_b = len(b)
        n_a = len(a)
        n_state = max(n_b, n_a) - 1

        if n_state == 0:
            return b[0] * x

        state = torch.zeros(n_state, dtype=x.dtype, device=x.device)
        y = torch.zeros_like(x)

        last = n_state - 1
        for n in range(len(x)):
            y[n] = b[0] * x[n] + state[0]
            # Ascending order: state[i + 1] still holds the previous sample's
            # value when state[i] is recomputed.
            for i in range(last):
                b_i = b[i + 1] if i + 1 < n_b else 0.0
                a_i = a[i + 1] if i + 1 < n_a else 0.0
                state[i] = b_i * x[n] - a_i * y[n] + state[i + 1]
            b_last = b[n_state] if n_state < n_b else 0.0
            a_last = a[n_state] if n_state < n_a else 0.0
            state[last] = b_last * x[n] - a_last * y[n]

        return y

    def extra_repr(self) -> str:
        """
        Summarize the configuration for ``print(module)``.

        Returns
        -------
        str
            Preset name, sampling rate, cutoff, order, number of passes and
            the ``learnable`` flag.
        """
        return (f"method={self.method}, fs={self.fs}, cutoff={self.cutoff} Hz, "
                f"order={self.order}, iterations={self.iterations}, learnable={self.learnable}")
[docs] class IHCPaulick2024(nn.Module): r""" Physiologically detailed inner hair cell transduction for CASP model. Converts basilar membrane velocity to receptor potential using a detailed physiological model of inner hair cell (IHC) mechano-electrical transduction (MET) and electrical circuit dynamics. Algorithm Overview ------------------ The IHC transduction process consists of four main stages: 1. **Stereocilia displacement scaling**: Convert BM velocity to displacement. .. math:: d_{\text{ster}}(t) = \alpha \cdot v_{\text{BM}}(t) where :math:`\alpha = 10^{-105/20}` (db2mag scaling factor). 2. **MET channel conductance**: Double-exponential sigmoid function fitted to physiological data. .. math:: G(d) = \frac{G_{\max}}{1 + \exp\left(\frac{x_0 - d}{s_1}\right) \left(1 + \exp\left(\frac{x_0 - d}{s_0}\right)\right)} where :math:`G_{\max} = 30` nS, :math:`x_0 = 20` nm (bias), :math:`s_0 = 16` nm (fast sensitivity), :math:`s_1 = 35` nm (slow sensitivity). 3. **Pre-charging**: Simulate 50 ms of activity with fixed conductance to reach steady-state resting potential :math:`V_{\text{rest}} = -57.03` mV before processing the signal. 4. **Electrical circuit ODE**: Forward Euler integration of membrane potential. .. math:: C_m \frac{dV}{dt} = I_{\text{MET}} + I_{K,f} + I_{K,s} where: * :math:`I_{\text{MET}} = -G(t) (V - E_P)` (MET current) * :math:`I_{K,f} = -G_{K,f} (V - E_{K,f})` (fast K+ current) * :math:`I_{K,s} = -G_{K,s} (V - E_{K,s})` (slow K+ current) Parameters ---------- fs : float Sampling rate in Hz. learnable : bool, optional If ``True``, all 13 physiological parameters become trainable ``nn.Parameter`` objects. If ``False``, they are registered as buffers. Default: ``False``. dtype : torch.dtype, optional Data type for computations. Recommended: ``torch.float32`` for device compatibility (MPS, CUDA), ``torch.float64`` for maximum numerical precision in ODE integration. Default: ``torch.float32``. Attributes ---------- fs : float Sampling rate in Hz. 
learnable : bool Whether parameters are trainable. dtype : torch.dtype Data type for computations (default: float32). scaling_factor : torch.Tensor or nn.Parameter Stereocilia displacement scaling (:math:`10^{-105/20}`). Units: dimensionless. Gmet_max : torch.Tensor or nn.Parameter Maximum MET channel conductance (30 nS). Units: Siemens (S). x0 : torch.Tensor or nn.Parameter Displacement bias (20 nm). Units: meters (m). s0 : torch.Tensor or nn.Parameter Fast sensitivity parameter (16 nm). Units: meters (m). s1 : torch.Tensor or nn.Parameter Slow sensitivity parameter (35 nm). Units: meters (m). EP : torch.Tensor or nn.Parameter Endocochlear potential (90 mV). Units: Volts (V). Cm : torch.Tensor or nn.Parameter Membrane capacitance (12.5 pF). Units: Farads (F). Gkf : torch.Tensor or nn.Parameter Fast K+ channel conductance (19.8 nS). Units: Siemens (S). Gks : torch.Tensor or nn.Parameter Slow K+ channel conductance (19.8 nS). Units: Siemens (S). Ekf : torch.Tensor or nn.Parameter Fast K+ reversal potential (-71 mV). Units: Volts (V). Eks : torch.Tensor or nn.Parameter Slow K+ reversal potential (-78 mV). Units: Volts (V). V_rest : torch.Tensor or nn.Parameter Resting potential (-57.03 mV). Units: Volts (V). G_precharge : torch.Tensor or nn.Parameter Pre-charge conductance (3.3514 nS). Units: Siemens (S). precharge_duration : float Duration of pre-charge simulation (50 ms). Units: seconds (s). Shape ----- - Input: :math:`(B, F, T)` or :math:`(F, T)` where :math:`B` is batch size, :math:`F` is frequency channels (typically 50 for CASP), :math:`T` is time samples. - Output: Same shape as input. Receptor potential in Volts, relative to pre-charge steady-state level. Notes ----- **Physiological Parameters** All parameters are derived from physiological measurements and modeling studies. The MET channel parameters (:math:`G_{\max}`, :math:`x_0`, :math:`s_0`, :math:`s_1`) are fitted to match hair cell transduction data. 
**Pre-charging Mechanism** The 50 ms pre-charge simulates the steady-state condition of the IHC at rest. This ensures the membrane potential starts at a physiologically realistic resting state (:math:`V_{\text{rest}} = -57.03` mV) rather than an arbitrary initial condition. The output voltage is computed relative to this pre-charge level to represent deviations from rest. **Numerical Stability** The ODE integration uses Forward Euler method with timestep :math:`\Delta t = 1/f_s`. **Float64 precision is required** to avoid accumulation of numerical errors over long signals. Using ``torch.float32`` may lead to instability or divergence. **Computational Complexity** The computational cost is :math:`O(B \cdot F \cdot (T + T_{\text{pre}}))` where :math:`T_{\text{pre}} = 0.05 \cdot f_s` is the pre-charge duration. For :math:`f_s = 44100` Hz, the pre-charge adds 2205 samples per batch x channel. Due to the sequential nature of ODE integration (each timestep depends on the previous), this implementation is **CPU-optimized**. GPU execution may not provide significant speedup and could be slower due to kernel launch overhead. **Connection to DRNL** This IHC model is designed to work with the output of :class:`DRNLFilterbank`, which provides basilar membrane velocity. The CASP pipeline is: .. 
code-block:: text Audio → DRNLFilterbank → IHCPaulick2024 → AdaptLoop → Modulation See Also -------- IHCEnvelope : Classical IHC envelope extraction (simpler, faster) DRNLFilterbank : Dual-resonance non-linear filterbank (typical input source) AdaptLoop : Auditory nerve adaptation (downstream processing) modfilterbank : Modulation filterbank (downstream processing) Examples -------- **Basic usage with DRNL velocity input:** >>> import torch >>> from torch_amt.common.ihc import IHCPaulick2024 >>> ihc = IHCPaulick2024(fs=44100) >>> vel = torch.randn(2, 50, 44100, dtype=torch.float64) # BM velocity from DRNL >>> V = ihc(vel) # Receptor potential >>> print(V.shape, V.dtype) torch.Size([2, 50, 44100]) torch.float64 **Batch processing with different batch sizes:** >>> ihc = IHCPaulick2024(fs=16000) >>> # Single sample >>> vel_single = torch.randn(1, 50, 16000, dtype=torch.float64) >>> V_single = ihc(vel_single) >>> # Large batch >>> vel_batch = torch.randn(8, 50, 16000, dtype=torch.float64) >>> V_batch = ihc(vel_batch) **Learnable physiological parameters for model fitting:** >>> ihc_learn = IHCPaulick2024(fs=44100, learnable=True) >>> print(f"Learnable params: {sum(p.numel() for p in ihc_learn.parameters())}") Learnable params: 13 >>> # All 13 physiological parameters can be optimized >>> optimizer = torch.optim.Adam(ihc_learn.parameters(), lr=1e-5) >>> # Note: May require constraints to ensure physiological plausibility References ---------- .. [1] L. Paulick, H. Relaño-Iborra, and T. Dau, "The Computational Auditory Signal Processing and Perception Model (CASP): A Revised Version," bioRxiv, 2024, doi: 10.1101/2024.02.02.578582. .. [2] T. Dau, B. Kollmeier, and A. Kohlrausch, "Modeling auditory processing of amplitude modulation. II. Spectral and temporal integration," *J. Acoust. Soc. Am.*, vol. 102, no. 5, pp. 2906-2919, 1997. """
[docs] def __init__(self, fs: float, learnable: bool = False, dtype: torch.dtype = torch.float32): super().__init__() self.fs = fs self.learnable = learnable self.dtype = dtype # Scaling factor: db2mag(-105) = 10^(-105/20) scaling_factor = torch.tensor(10.0 ** (-105.0 / 20.0), dtype=dtype) # MET channel parameters Gmet_max = torch.tensor(30e-9, dtype=dtype) # Max conductance (S) x0 = torch.tensor(20e-9, dtype=dtype) # Displacement bias (m) s0 = torch.tensor(16e-9, dtype=dtype) # Fast sensitivity (m) s1 = torch.tensor(35e-9, dtype=dtype) # Slow sensitivity (m) # Electrical circuit parameters EP = torch.tensor(90e-3, dtype=dtype) # Endocochlear potential (V) Cm = torch.tensor(12.5e-12, dtype=dtype) # Membrane capacitance (F) Gkf = torch.tensor(19.8e-9, dtype=dtype) # Fast K+ conductance (S) Gks = torch.tensor(19.8e-9, dtype=dtype) # Slow K+ conductance (S) Ekf = torch.tensor(-71e-3, dtype=dtype) # Fast K+ reversal potential (V) Eks = torch.tensor(-78e-3, dtype=dtype) # Slow K+ reversal potential (V) # Pre-charging parameters V_rest = torch.tensor(-0.05703, dtype=dtype) # Resting potential (V) G_precharge = torch.tensor(3.3514e-9, dtype=dtype) # Pre-charge conductance (S) precharge_duration = 50e-3 # seconds if learnable: self.scaling_factor = nn.Parameter(scaling_factor) self.Gmet_max = nn.Parameter(Gmet_max) self.x0 = nn.Parameter(x0) self.s0 = nn.Parameter(s0) self.s1 = nn.Parameter(s1) self.EP = nn.Parameter(EP) self.Cm = nn.Parameter(Cm) self.Gkf = nn.Parameter(Gkf) self.Gks = nn.Parameter(Gks) self.Ekf = nn.Parameter(Ekf) self.Eks = nn.Parameter(Eks) # V_rest: always buffer (never has gradient - used only for initialization) self.register_buffer('V_rest', V_rest) self.G_precharge = nn.Parameter(G_precharge) else: self.register_buffer('scaling_factor', scaling_factor) self.register_buffer('Gmet_max', Gmet_max) self.register_buffer('x0', x0) self.register_buffer('s0', s0) self.register_buffer('s1', s1) self.register_buffer('EP', EP) self.register_buffer('Cm', Cm) 
self.register_buffer('Gkf', Gkf) self.register_buffer('Gks', Gks) self.register_buffer('Ekf', Ekf) self.register_buffer('Eks', Eks) self.register_buffer('V_rest', V_rest) self.register_buffer('G_precharge', G_precharge) self.precharge_duration = precharge_duration # Pre-charge cache (only used when learnable=False for speed) # Cache is keyed by (batch_size, num_channels, device) self._precharge_cache = {} if not learnable else None
def _compute_met_conductance(self, ster_disp: torch.Tensor) -> torch.Tensor:
    r"""
    Map stereocilia displacement to MET channel conductance.

    Evaluates the double-exponential sigmoid

    .. math::
        G(d) = \frac{G_{\max}}{1 + \exp\left(\frac{x_0 - d}{s_1}\right)
               \left(1 + \exp\left(\frac{x_0 - d}{s_0}\right)\right)}

    element-wise, using ``self.Gmet_max``, ``self.x0``, ``self.s0`` and
    ``self.s1``.

    Parameters
    ----------
    ster_disp : torch.Tensor
        Stereocilia displacement in meters. Shape: :math:`(B, F, T)`.

    Returns
    -------
    torch.Tensor
        MET channel conductance in Siemens, same shape as ``ster_disp``.
    """
    offset = self.x0 - ster_disp
    slow_term = torch.exp(offset / self.s1)
    fast_term = torch.exp(offset / self.s0)
    return self.Gmet_max / (1.0 + slow_term * (1.0 + fast_term))


def _precharge_circuit(self, batch_size: int, num_channels: int) -> torch.Tensor:
    """
    Return the pre-charged membrane voltage used as the ODE initial state.

    The computation is delegated to the scripted helper
    ``_precharge_circuit_jit``, which receives the sampling rate,
    ``self.precharge_duration``, ``self.V_rest``, ``self.G_precharge`` and the
    circuit constants.

    Parameters
    ----------
    batch_size : int
        Batch size :math:`B`.

    num_channels : int
        Number of frequency channels :math:`F`.

    Returns
    -------
    torch.Tensor
        Pre-charged voltage in Volts. Shape: :math:`(B, F)`.

    Notes
    -----
    When ``self._precharge_cache`` is a dict (non-learnable mode), results
    are memoized per ``(batch_size, num_channels, device)``. A clone is
    stored and a clone is returned, so callers never share storage with the
    cache. When the cache is ``None`` (learnable mode), the helper is called
    every time.
    """
    device = self.V_rest.device
    cache = self._precharge_cache
    caching = cache is not None
    key = (batch_size, num_channels, device)

    if caching:
        hit = cache.get(key)
        if hit is not None:
            return hit.clone()

    voltage = _precharge_circuit_jit(
        batch_size=batch_size,
        num_channels=num_channels,
        fs=self.fs,
        precharge_duration=self.precharge_duration,
        V_rest=self.V_rest,
        G_precharge=self.G_precharge,
        EP=self.EP,
        Gkf=self.Gkf,
        Gks=self.Gks,
        Ekf=self.Ekf,
        Eks=self.Eks,
        Cm=self.Cm,
        dtype=self.dtype,
        device=device,
    )

    if caching:
        cache[key] = voltage.clone()
    return voltage


def _solve_circuit_ode(self, G: torch.Tensor, V_precharge: torch.Tensor) -> torch.Tensor:
    r"""
    Integrate the membrane-potential ODE for a given MET conductance.

    The membrane equation being solved is

    .. math::
        C_m \frac{dV}{dt} = I_{\text{MET}} + I_{K,f} + I_{K,s}

    where:

    * :math:`I_{\text{MET}} = -G(t) (V - E_P)` (MET current)
    * :math:`I_{K,f} = -G_{K,f} (V - E_{K,f})` (fast K+ current)
    * :math:`I_{K,s} = -G_{K,s} (V - E_{K,s})` (slow K+ current)

    Parameters
    ----------
    G : torch.Tensor
        MET channel conductance in Siemens. Shape: :math:`(B, F, T)`.

    V_precharge : torch.Tensor
        Pre-charged voltage in Volts. Shape: :math:`(B, F)`.

    Returns
    -------
    torch.Tensor
        Receptor potential in Volts (relative to pre-charge level).
        Shape: :math:`(B, F, T)`.

    Notes
    -----
    The integration itself (Forward Euler with timestep
    :math:`\Delta t = 1/f_s`) is carried out by the module-level helper
    ``_solve_circuit_ode_jit``. This method only forwards the inputs together
    with the circuit constants ``EP``, ``Gkf``, ``Gks``, ``Ekf``, ``Eks``,
    ``Cm`` and the sampling rate.
    """
    return _solve_circuit_ode_jit(
        G=G,
        V_precharge=V_precharge,
        EP=self.EP,
        Gkf=self.Gkf,
        Gks=self.Gks,
        Ekf=self.Ekf,
        Eks=self.Eks,
        Cm=self.Cm,
        fs=self.fs,
    )
def forward(self, vel: torch.Tensor) -> torch.Tensor:
    r"""
    Convert basilar membrane velocity to receptor potential.

    Applies the complete IHC transduction pipeline:
    scaling → MET conductance → pre-charging → ODE integration.

    Parameters
    ----------
    vel : torch.Tensor
        Basilar membrane velocity in m/s (typically from
        :class:`DRNLFilterbank`). Shape: :math:`(B, F, T)` or :math:`(F, T)`
        where :math:`B` is batch size, :math:`F` is frequency channels,
        :math:`T` is time samples.

    Returns
    -------
    torch.Tensor
        Receptor potential in Volts (relative to resting potential).
        Same shape as input.

    Raises
    ------
    ValueError
        If ``vel`` is neither 2D nor 3D.

    Notes
    -----
    The processing steps are:

    1. Scale velocity to stereocilia displacement
    2. Compute MET channel conductance :math:`G(d)`
    3. Pre-charge circuit to steady-state (50 ms simulation)
    4. Integrate ODE for membrane potential

    For 2D input :math:`(F, T)`, a batch dimension is temporarily added and
    removed after processing.

    **Timing**: For :math:`f_s = 44100` Hz and 1 second signal with 50 channels:

    * Pre-charge: 2205 samples x 50 channels = 110,250 operations
    * Main ODE: 44100 samples x 50 channels = 2,205,000 operations
    * Total: ~2.3M operations per batch item
    """
    # Reject unsupported ranks up front instead of failing later with an
    # opaque tuple-unpacking error.
    if vel.ndim not in (2, 3):
        raise ValueError(
            "Expected input of shape (B, F, T) or (F, T), "
            f"got shape {tuple(vel.shape)}"
        )

    # A 2D input is processed as a batch of one.
    unbatched = vel.ndim == 2
    if unbatched:
        vel = vel.unsqueeze(0)

    batch_size, num_channels = vel.shape[:2]

    # 1. Scale to stereocilia displacement
    ster_disp = self.scaling_factor * vel

    # 2. Compute MET channel conductance
    G = self._compute_met_conductance(ster_disp)

    # 3. Pre-charge circuit to steady state
    V_precharge = self._precharge_circuit(batch_size, num_channels)

    # 4. Solve circuit ODE
    V = self._solve_circuit_ode(G, V_precharge)

    # Drop the temporary batch dimension again for 2D input.
    return V.squeeze(0) if unbatched else V
def extra_repr(self) -> str:
    """
    Build the summary shown inside the module's printed representation.

    Returns
    -------
    str
        Comma-separated summary with the sampling rate, the pre-charge
        duration converted to milliseconds, the learnable flag and the dtype.
    """
    precharge_ms = self.precharge_duration * 1000
    fields = (
        f"fs={self.fs} Hz",
        f"precharge={precharge_ms:.1f} ms",
        f"learnable={self.learnable}",
        f"dtype={self.dtype}",
    )
    return ", ".join(fields)