Bank of gammatone auditory filters using Lyon’s all-pole approximation.
Implements parallel gammatone filters spaced on the ERB-frequency scale, which model
the bandpass filtering performed by the human cochlea. Each filter approximates the
impulse response:
\[g(t) = t^{n-1} e^{-2\pi \beta t} \cos(2\pi f_c t + \phi)\]
where \(t \geq 0\), \(n\) is the filter order (typically 4), \(\beta\)
is the bandwidth parameter, \(f_c\) is the center frequency, and \(\phi\) is
the phase offset.
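As a rough illustration, the impulse response can be sampled directly. This is a minimal sketch that assumes the standard gammatone form \(t^{n-1} e^{-2\pi\beta t} \cos(2\pi f_c t + \phi)\) without amplitude normalization. The function name gammatone_impulse_response and the choice of \(\beta\) as betamul times the ERB at \(f_c\) are illustrative and not part of the library.

```python
import math

def gammatone_impulse_response(fc, beta, fs, n=4, phi=0.0, num_samples=400):
    """Sample g(t) = t^(n-1) * exp(-2*pi*beta*t) * cos(2*pi*fc*t + phi), t >= 0."""
    samples = []
    for k in range(num_samples):
        t = k / fs
        # Gamma-shaped envelope: polynomial rise, exponential decay.
        envelope = t ** (n - 1) * math.exp(-2.0 * math.pi * beta * t)
        # Tone carrier at the center frequency.
        samples.append(envelope * math.cos(2.0 * math.pi * fc * t + phi))
    return samples

fs = 16000.0
fc = 1000.0
# Assumed bandwidth: 1.019 * ERB(fc), with ERB(f) = 24.7 * (4.37 f / 1000 + 1).
beta = 1.019 * 24.7 * (4.37 * fc / 1000.0 + 1.0)
ir = gammatone_impulse_response(fc, beta, fs)

# The envelope t^(n-1) * exp(-2*pi*beta*t) peaks at t = (n - 1) / (2*pi*beta).
envelope_peak_time = (4 - 1) / (2.0 * math.pi * beta)
```

The envelope peak time shows how a larger \(\beta\) (wider bandwidth) gives a shorter, earlier-peaking response.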
The all-pole approximation factorizes the filter as a cascade of first-order complex
resonators, which provides:
Numerical stability: No polynomial expansion required
Accurate low-frequency response: No pole magnitude limiting needed
Efficient computation: Cascade structure with \(n\) first-order sections
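The cascade structure can be sketched in plain Python. The pole placement \(p = e^{(-2\pi\beta + j 2\pi f_c)/f_s}\) and the absence of gain normalization are assumptions made for illustration; the library's actual coefficients are not shown here, and complex_resonator_cascade is a hypothetical helper name.

```python
import cmath
import math

def complex_resonator_cascade(x, fc, beta, fs, n=4):
    """Pass x through n identical one-pole complex resonators in series."""
    # Pole radius controls the decay (bandwidth); pole angle sets the center frequency.
    pole = cmath.exp(complex(-2.0 * math.pi * beta / fs, 2.0 * math.pi * fc / fs))
    signal = [complex(v) for v in x]
    for _ in range(n):
        state = 0j
        stage_output = []
        for v in signal:
            state = v + pole * state  # y[k] = x[k] + p * y[k-1]
            stage_output.append(state)
        signal = stage_output
    return signal, pole

impulse = [1.0] + [0.0] * 31
response, pole = complex_resonator_cascade(impulse, fc=1000.0, beta=135.0, fs=16000.0)
# A real-valued band-pass output is obtained from the real part of the complex signal.
real_output = [2.0 * y.real for y in response]
```

For an impulse input, the n-stage cascade produces \(\binom{k+n-1}{n-1} p^k\), whose polynomial growth of degree \(n-1\) is the discrete counterpart of the \(t^{n-1}\) envelope. No denominator polynomial is ever expanded, since each stage only needs the single pole value.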
Two implementations are provided:
‘sos’ (default): Cascade of first-order sections. Numerically stable, no
frequency distortion. Recommended for all applications.
‘poly’: Polynomial expansion with pole limiting (MAX_POLE_MAG=0.9). Legacy
implementation that causes frequency shifts for low-frequency channels.
>>> import torch
>>> from torch_amt.common.filterbanks import GammatoneFilterbank
>>>
>>> # Create filterbank with 30 channels from 100 to 8000 Hz
>>> fb = GammatoneFilterbank((100.0, 8000.0), fs=16000, n=4)
>>> print(fb.num_channels)
30
>>> print(fb.fc[:3])  # First three center frequencies
tensor([100.0000, 138.6141, 181.7625])
>>>
>>> # Process audio signal
>>> x = torch.randn(2, 16000)  # (batch=2, time=16000)
>>> y = fb(x)
>>> print(y.shape)  # (batch=2, channels=30, time=16000)
torch.Size([2, 30, 16000])
Custom center frequencies:
>>> # Specify exact center frequencies
>>> fc_custom = torch.tensor([200.0, 500.0, 1000.0, 2000.0, 4000.0])
>>> fb_custom = GammatoneFilterbank(fc_custom, fs=16000)
>>> print(fb_custom.num_channels)
5
Learnable filters for optimization:
>>> fb_learnable = GammatoneFilterbank(
...     (100.0, 8000.0), fs=16000, learnable=True
... )
>>> # Count learnable parameters (2*F for poles+gains in SOS mode)
>>> n_params = sum(p.numel() for p in fb_learnable.parameters())
>>> print(f"Learnable parameters: {n_params}")
Learnable parameters: 60
Comparing implementations:
>>> fb_sos = GammatoneFilterbank((100.0, 8000.0), fs=16000, implementation='sos')
>>> fb_poly = GammatoneFilterbank((100.0, 8000.0), fs=16000, implementation='poly')
>>> x_test = torch.randn(1, 8000)
>>> y_sos = fb_sos(x_test)
>>> y_poly = fb_poly(x_test)
>>> # SOS implementation is more accurate at low frequencies
>>> print(f"SOS output shape: {y_sos.shape}")
SOS output shape: torch.Size([1, 30, 8000])
Notes
The IIR filtering operation (_apply_iir()) processes samples sequentially within
each channel, but channels are processed in parallel, making GPU/Metal acceleration effective
for multi-channel processing.
Learnable Parameters:
When learnable=True:
SOS mode: 2F parameters (F complex poles + F complex gains) = 4F real values
Poly mode: F × (nb + na) complex coefficients
where F is the number of channels. For a typical 30-channel filterbank with n=4:
Stable even for very narrow-band low-frequency filters
The poly implementation has known issues:
Requires pole limiting (MAX_POLE_MAG=0.9) for stability
Causes frequency shifts for low-frequency channels (< 500 Hz)
Provided only for compatibility with legacy MATLAB code
Filter Bandwidth:
The default betamul=1.019 gives an equivalent rectangular bandwidth (ERB) that matches
human auditory filter bandwidths according to Glasberg & Moore (1990). The 3-dB
bandwidth is approximately:
fc (Tensor | Tuple[float, float]) – Center frequencies. If tuple (flow, fhigh), generates ERB-spaced frequencies
between flow and fhigh using erbspacebw() with bwmul=1.0.
where ERB is the Equivalent Rectangular Bandwidth from Glasberg & Moore (1990).
If betamul is None, the standard Patterson et al. (1987) value of 1.019 is used.
Filters the input signal through all frequency channels in parallel, producing
a multi-channel output representing the cochlear frequency decomposition.
This method dispatches to either _forward_sos() or _forward_poly()
based on the implementation parameter set during initialization.
The output is real-valued and represents the instantaneous envelope at each
frequency channel. The MATLAB convention of multiplying by 2 is applied to
compensate for using only the analytic signal’s real part.
Computational Complexity:
SOS: O(nBFT) where n is filter order
Poly: O(BFT(nb + na)) where nb, na are coefficient lengths
For typical parameters (n=4, B=2, F=30, T=16000), the SOS bound nBFT works out to about 3.8M first-order section updates.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be copied to that device
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be copied to that device
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be copied to that device
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
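A short sketch of the deduplication behavior described in the note, using standard torch.nn containers. The variable names are illustrative.

```python
import torch.nn as nn

l = nn.Linear(2, 2)
net = nn.Sequential(l, l)  # the same Linear instance is registered twice

# Default behavior: each module instance is yielded once.
unique_names = [name for name, _ in net.named_modules()]

# With remove_duplicate=False, every registration is yielded.
all_names = [name for name, _ in net.named_modules(remove_duplicate=False)]
```

Here unique_names contains the root module (empty name) and the first registration of l only, while all_names also includes the second registration.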
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm’s running_mean
is not a parameter, but is part of the module’s state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting persistent to False. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module’s
state_dict.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If None, then operations
that run on buffers, such as cuda, are ignored. If None,
the buffer is not included in the module’s state_dict.
persistent (bool): whether the buffer is part of this module’s state_dict.
The hook will be called every time after forward() has computed an output.
If with_kwargs is False or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
output. It can modify the input inplace but it will not have effect on
forward since this is called after forward() is called. The hook
should have the following signature:
hook(module, args, output) -> None or modified output
If with_kwargs is True, the forward hook will be passed the
kwargs given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature:
hook(module, args, kwargs, output) -> None or modified output
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided hook will be fired
before all existing forward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this torch.nn.Module. Note that global
forward hooks registered with
register_module_forward_hook() will fire before all hooks
registered by this method.
Default: False
with_kwargs (bool): If True, the hook will be passed the
kwargs given to the forward function.
Default: False
always_call (bool): If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
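A minimal sketch of registering and removing a forward hook, here on a plain nn.Linear rather than on a filterbank. The hook name record_shape and the shapes list are illustrative.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
shapes = []

def record_shape(module, args, output):
    # Returning None leaves the module output unchanged.
    shapes.append(tuple(output.shape))

handle = layer.register_forward_hook(record_shape)
layer(torch.zeros(4, 3))   # hook fires after forward()
handle.remove()
layer(torch.zeros(1, 3))   # hook no longer registered
```

After handle.remove() the second call is not recorded, so shapes holds a single entry for the first call.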
The hook will be called every time before forward() is invoked.
If with_kwargs is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature:
hook(module, args) -> None or modified input
If with_kwargs is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature:
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing forward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this torch.nn.Module. Note that global
forward_pre hooks registered with
register_module_forward_pre_hook() will fire before all
hooks registered by this method.
Default: False
with_kwargs (bool): If true, the hook will be passed the kwargs
given to the forward function.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_input and grad_output are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of grad_input in
subsequent computations. grad_input will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in grad_input and grad_output will be None for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_output is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of grad_output in
subsequent computations. Entries in grad_output will be None for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Register a post-hook to be run after module’s load_state_dict() is called.
It should have the following signature::
hook(module, incompatible_keys) -> None
The module argument is the current module that this hook is registered
on, and the incompatible_keys argument is a NamedTuple consisting
of attributes missing_keys and unexpected_keys. missing_keys
is a list of str containing the missing keys and
unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling load_state_dict() with
strict=True are affected by modifications the hook makes to
missing_keys or unexpected_keys, as expected. Additions to either
set of keys will result in an error being thrown when strict=True, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state
found within the state_dict. Implement this function and a corresponding
get_extra_state() for your module if you need to store extra state within its
state_dict.
Set the submodule given by target if it exists, otherwise throw an error.
Note
If strict is set to False (default), the method will replace an existing submodule
or create a new submodule if the parent module exists. If strict is set to True,
the method will only attempt to replace an existing submodule and throw an error if
the submodule does not exist.
For example, let’s say you have an nn.Module A that
looks like this:
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To override the Conv2d with a new submodule Linear, you
could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),
where strict can be either True or False.
To add a new submodule Conv2d to the existing net_b module,
you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).
If you instead set strict=True and call
set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError
will be raised because net_b does not have a submodule named conv.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
strict: If False, the method will replace an existing submodule
or create a new submodule if the parent module exists. If True,
the method will only attempt to replace an existing submodule and throw an error
if the submodule doesn’t already exist.
Raises:
ValueError: If the target string is empty or if module is not an instance of nn.Module.
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references
to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for
destination, prefix and keep_vars in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Warning
Please avoid the use of argument destination as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional): by default the Tensors
returned in the state dict are detached from autograd. If it’s
set to True, detaching will not be performed.
Default: False.
Returns:
dict:
a dictionary containing a whole state of the module
Its signature is similar to torch.Tensor.to(), but only accepts
floating point or complex dtypes. In addition, this method will
only cast the floating point or complex parameters and buffers to dtype
(if given). The integral parameters and buffers will be moved to
device, if that is given, but with dtypes unchanged. When
non_blocking is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
Args:
device (torch.device): the desired device of the parameters
and buffers in this module
dtype (torch.dtype): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. Dropout, BatchNorm,
etc.
Args:
mode (bool): whether to set training mode (True) or evaluation mode (False). Default: True.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on XPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be copied to that device
Dual Resonance Non-Linear (DRNL) filterbank for basilar membrane simulation.
Implements the DRNL auditory filterbank model used in the Paulick et al. (2024)
CASP model. The DRNL consists of two parallel signal paths (linear and nonlinear)
that are summed to produce the basilar membrane velocity response.
The architecture models the cochlea’s dual mechanism for sound processing:
Linear path: Provides level-independent frequency selectivity via
gammatone bandpass filtering and lowpass smoothing
Nonlinear path: Provides level-dependent compression via broken-stick
nonlinearity sandwiched between gammatone filters
>>> import torch
>>> from torch_amt.common.filterbanks import DRNLFilterbank
>>>
>>> # Create 50-channel DRNL from 250 to 8000 Hz
>>> drnl = DRNLFilterbank((250, 8000), fs=44100, n_channels=50)
>>> print(drnl.num_channels)
50
>>>
>>> # Process 1 second of audio
>>> x = torch.randn(44100)
>>> y = drnl(x)
>>> print(y.shape)
torch.Size([50, 44100])
Sets up DRNL filterbank with specified center frequencies, computes
CF-dependent parameters (gains, bandwidths, nonlinearity coefficients),
and precomputes filter coefficients for efficient processing.
If torch.Tensor: Explicit center frequencies in Hz, shape (n,).
User has full control over frequency spacing and count.
If tuple (flow, fhigh): Auto-generate n_channels ERB-spaced
frequencies between flow and fhigh Hz using erbspacebw().
fs (float) – Sampling rate in Hz. Typical values: 44100, 48000, or 32000.
Must be at least 2x the highest center frequency to avoid aliasing.
n_channels (int) – Number of frequency channels. Default: 50.
Only used when fc is a tuple. Determines ERB spacing resolution:
more channels = finer frequency resolution, higher computational cost.
‘NH’: Normal hearing with intact cochlear compression.
Nonlinearity parameters (a, b, c) computed from Paulick et al. (2024).
At CF <= 1000 Hz: frequency-dependent compression. Above 1000 Hz:
frozen at 1500 Hz parameters.
‘HIx’: Hearing impaired without cochlear compression.
Linear nonlinearity (a=0, removes compression term). Models
outer hair cell dysfunction.
Model parametrization variant. Default: ‘paulick2024’.
‘paulick2024’: Current CASP version (Paulick et al. 2024).
Linear path: 2 cascaded gammatone + 4 cascaded lowpass.
Nonlinear path: 2 cascaded gammatone + 1 lowpass.
‘jepsen2008’: Previous version (Jepsen et al. 2008).
Linear path: 3 cascaded gammatone + 4 cascaded lowpass.
Nonlinear path: 3 cascaded gammatone + 3 cascaded lowpass.
learnable (bool) – If True, makes CF-dependent parameters (CF_lin, BW_lin, g, CF_nlin,
BW_nlin, a, b, c) into nn.Parameter for gradient-based optimization.
Default: False (parameters are buffers, not optimized).
dtype (dtype) – Data type for computations and parameters. Default: torch.float64.
Double precision recommended for numerical stability in cascaded
IIR filtering (gammatone and Butterworth filters).
Raises:
ValueError – If model is not ‘paulick2024’ or ‘jepsen2008’.
When fc is a tuple, frequencies are spaced according to the
Equivalent Rectangular Bandwidth (ERB) scale:
\[\text{ERB}(f) = 24.7 (4.37 f / 1000 + 1)\]
This spacing matches the frequency resolution of the human auditory system,
with finer spacing at low frequencies and coarser spacing at high frequencies.
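The spacing can be sketched from the ERB formula alone. This assumes channels are placed at equal steps on the ERB-number scale, obtained by integrating \(1/\text{ERB}(f)\), with the first and last channels at flow and fhigh. The helper names are hypothetical and the actual erbspacebw() may differ in detail.

```python
import math

def erb_bandwidth(f_hz):
    """ERB(f) = 24.7 * (4.37 f / 1000 + 1), in Hz."""
    return 24.7 * (4.37 * f_hz / 1000.0 + 1.0)

# Integrating 1 / ERB(f) gives the ERB-number scale with this constant (about 9.26).
ERB_SCALE = 1000.0 / (24.7 * 4.37)

def hz_to_erb_number(f_hz):
    return ERB_SCALE * math.log(1.0 + 4.37 * f_hz / 1000.0)

def erb_number_to_hz(erb_number):
    return (math.exp(erb_number / ERB_SCALE) - 1.0) * 1000.0 / 4.37

def erb_spaced_frequencies(flow, fhigh, n_channels):
    """Place n_channels (>= 2) center frequencies at equal ERB-number steps."""
    if n_channels < 2:
        raise ValueError("n_channels must be at least 2")
    e_low = hz_to_erb_number(flow)
    e_high = hz_to_erb_number(fhigh)
    step = (e_high - e_low) / (n_channels - 1)
    return [erb_number_to_hz(e_low + i * step) for i in range(n_channels)]

fc = erb_spaced_frequencies(100.0, 8000.0, 30)
```

With 30 channels between 100 and 8000 Hz, the first three centers come out at roughly 100, 138.6 and 181.8 Hz, and the gap between neighbors grows with frequency.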
Parameter Initialization:
During initialization, the following steps occur:
Generate or validate center frequencies (self.fc)
Compute CF-dependent parameters (_compute_parameters):
- Linear path: CF_lin, BW_lin, LP_lin_cutoff, g
- Nonlinear path: CF_nlin, BW_nlin, LP_nlin_cutoff, a, b, c
Store coefficients as lists for scipy.signal.lfilter compatibility
Computational Cost:
Filter coefficient computation is O(n_channels), performed once at
initialization. Forward pass cost is O(n_channels * batch * time),
dominated by sequential filtering operations.
Model Variants:
The two model variants differ in filter orders, affecting frequency
selectivity and computational cost:
Higher orders = sharper frequency tuning but more cascaded filtering.
Processes audio through dual-path (linear + nonlinear) DRNL filterbank,
computing basilar membrane velocity response for each frequency channel.
Each channel applies independent linear and nonlinear processing, then
sums the paths.
Linear regime (\(a|x|\) term): Dominates at low levels
Compressive regime (\(b|x|^c\) term): Dominates at high levels,
with \(c \approx 0.25\) providing ~4:1 compression
Transition: Occurs at \(|x| = (a/b)^{1/(c-1)}\)
For normal hearing (NH), this creates level-dependent gain: high gain
for quiet sounds, reduced gain for loud sounds. For hearing impaired (HIx),
the ‘HIx’ option sets \(a=0\) to model the loss of cochlear compression
(outer hair cell dysfunction).
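The two regimes and the transition point can be illustrated with a small sketch. It assumes the common broken-stick form \(y = \text{sign}(x)\,\min(a|x|,\, b|x|^c)\), which is not spelled out above. The parameter values are purely illustrative and are not the CF-dependent values from Paulick et al. (2024).

```python
import math

def broken_stick(x, a, b, c):
    """Assumed form: sign(x) * min(a * |x|, b * |x| ** c)."""
    magnitude = abs(x)
    return math.copysign(min(a * magnitude, b * magnitude ** c), x)

def knee_point(a, b, c):
    """Input magnitude where a * |x| equals b * |x| ** c."""
    return (a / b) ** (1.0 / (c - 1.0))

# Illustrative parameters only.
a, b, c = 100.0, 1.0, 0.25
knee = knee_point(a, b, c)

low_level = broken_stick(1e-4, a, b, c)    # below the knee: a * |x|
high_level = broken_stick(1.0, a, b, c)    # above the knee: b * |x| ** c
higher_level = broken_stick(1e4, a, b, c)  # input raised by a factor of 10**4
```

With \(c = 0.25\), raising the input by a factor of \(10^4\) above the knee raises the output by only a factor of 10, which is the 4:1 compression on a dB scale.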
Computational Complexity:
For input of length T samples, B batches, F channels:
Filtering operations: O(B * F * T * (n_gt + n_lp)) per path
Nonlinearity: O(B * F * T)
Total: O(B * F * T * max_filter_order)
For B=2, F=50, T=44100, n_gt=2, n_lp=4: ~26M operations (~200 ms on CPU)
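The per-path estimate can be reproduced with simple arithmetic. This sketch counts one unit of work per sample per cascaded filter stage, using the stage counts listed for the two model variants. It ignores the order of each stage and the nonlinearity, so it is only an order-of-magnitude guide. The dictionary and function names are illustrative.

```python
# (gammatone stages, lowpass stages) per path, from the variant descriptions.
VARIANT_STAGES = {
    "paulick2024": {"linear": (2, 4), "nonlinear": (2, 1)},
    "jepsen2008": {"linear": (3, 4), "nonlinear": (3, 3)},
}

def filtering_updates(batch, channels, samples, n_gammatone, n_lowpass):
    """B * F * T * (n_gt + n_lp): filter-stage updates for one path."""
    return batch * channels * samples * (n_gammatone + n_lowpass)

B, F, T = 2, 50, 44100
updates = {
    variant: {
        path: filtering_updates(B, F, T, n_gt, n_lp)
        for path, (n_gt, n_lp) in paths.items()
    }
    for variant, paths in VARIANT_STAGES.items()
}
```

For the paulick2024 linear path this gives 26,460,000 updates, which matches the ~26M figure quoted for n_gt=2 and n_lp=4.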
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
(The diagram shows an nn.ModuleA. A which has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm’s running_mean
is not a parameter, but is part of the module’s state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting persistent to False. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module’s
state_dict.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If None, then operations
that run on buffers, such as cuda, are ignored. If None,
the buffer is not included in the module’s state_dict.
persistent (bool): whether the buffer is part of this module’s
The hook will be called every time after forward() has computed an output.
If with_kwargs is False or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
output. It can modify the input inplace but it will not have effect on
forward since this is called after forward() is called. The hook
should have the following signature:
hook(module,args,output)->Noneormodifiedoutput
If with_kwargs is True, the forward hook will be passed the
kwargs given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided hook will be fired
before all existing forward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this torch.nn.Module. Note that global
forward hooks registered with
register_module_forward_hook() will fire before all hooks
registered by this method.
Default: False
with_kwargs (bool): If True, the hook will be passed the
kwargs given to the forward function.
Default: False
always_call (bool): If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
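The forward-hook behavior described above can be sketched as follows. This is a minimal illustration on a plain nn.Linear layer; the hook names save_output_shape and double_output are made up for the example and are not part of any library.

```python
import torch
from torch import nn

captured = {}

def save_output_shape(module, args, output):
    # Runs after forward(); returning None leaves the output unchanged.
    captured["shape"] = tuple(output.shape)

def double_output(module, args, output):
    # Returning a value replaces the output seen by the caller.
    return output * 2

layer = nn.Linear(4, 3)
x = torch.ones(2, 4)
baseline = layer(x)

# Hooks fire in registration order unless prepend=True is used.
h1 = layer.register_forward_hook(save_output_shape)
h2 = layer.register_forward_hook(double_output)
hooked = layer(x)

# Each call returned a RemovableHandle; remove() detaches the hook again.
h1.remove()
h2.remove()
restored = layer(x)
```

After both handles are removed, the layer behaves exactly as it did before the hooks were registered.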
The hook will be called every time before forward() is invoked.
If with_kwargs is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature:
hook(module, args) -> None or modified input
If with_kwargs is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature:
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing forward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this torch.nn.Module. Note that global
forward_pre hooks registered with
register_module_forward_pre_hook() will fire before all
hooks registered by this method.
Default: False
with_kwargs (bool): If true, the hook will be passed the kwargs
given to the forward function.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_input and grad_output are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of grad_input in
subsequent computations. grad_input will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in grad_input and grad_output will be None for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_output is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of grad_output in
subsequent computations. Entries in grad_output will be None for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Register a post-hook to be run after module’s load_state_dict() is called.
It should have the following signature::
hook(module, incompatible_keys) -> None
The module argument is the current module that this hook is registered
on, and the incompatible_keys argument is a NamedTuple consisting
of attributes missing_keys and unexpected_keys. missing_keys
is a list of str containing the missing keys and
unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling load_state_dict() with
strict=True are affected by modifications the hook makes to
missing_keys or unexpected_keys, as expected. Additions to either
set of keys will result in an error being thrown when strict=True, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state
found within the state_dict. Implement this function and a corresponding
get_extra_state() for your module if you need to store extra state within its
state_dict.
Set the submodule given by target if it exists, otherwise throw an error.
Note
If strict is set to False (default), the method will replace an existing submodule
or create a new submodule if the parent module exists. If strict is set to True,
the method will only attempt to replace an existing submodule and throw an error if
the submodule does not exist.
For example, let’s say you have an nn.Module A that
looks like this:
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To override the Conv2d with a new submodule Linear, you
could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),
where strict could be True or False.
To add a new submodule Conv2d to the existing net_b module,
you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).
In the above, if you set strict=True and call
set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError
will be raised because net_b does not have a submodule named conv.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
strict: If False, the method will replace an existing submodule
or create a new submodule if the parent module exists. If True,
the method will only attempt to replace an existing submodule and throw an error
if the submodule doesn’t already exist.
Raises:
ValueError: If the target string is empty or if module is not an instance of nn.Module.
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references
to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for
destination, prefix and keep_vars in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Warning
Please avoid the use of argument destination as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional): by default the Tensors
returned in the state dict are detached from autograd. If it’s
set to True, detaching will not be performed.
Default: False.
Returns:
dict:
a dictionary containing a whole state of the module
Its signature is similar to torch.Tensor.to(), but only accepts
floating point or complex dtypes. In addition, this method will
only cast the floating point or complex parameters and buffers to dtype
(if given). The integral parameters and buffers will be moved to
device, if that is given, but with dtypes unchanged. When
non_blocking is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
Args:
device (torch.device): the desired device of the parameters
and buffers in this module
dtype (torch.dtype): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. Dropout, BatchNorm,
etc.
Args:
mode (bool): whether to set training mode (True) or evaluation
mode (False). Default: True.
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on XPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Fast DRNL filterbank using FFT-based convolution with pre-computed impulse responses.
This is a performance-optimized drop-in replacement for DRNLFilterbank
that achieves ~100x speedup while maintaining excellent numerical accuracy
(max absolute diff < 3e-4).
Key Optimizations:
Pre-compute impulse responses from all IIR filters at initialization
Zero loops in forward pass - all operations vectorized
Pre-flip impulse responses (conv1d does cross-correlation)
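The first and third optimizations can be illustrated with a small pure-Python sketch. It is not the library's implementation: it uses a hypothetical first-order IIR filter (iir_first_order) and direct sums instead of FFTs, and only shows why a recursive filter can be replaced by its pre-computed impulse response, and why that response is flipped before being passed to a cross-correlation routine such as conv1d.

```python
def iir_first_order(x, a):
    """Reference recursive filter: y[n] = x[n] + a * y[n-1]."""
    y, prev = [], 0.0
    for sample in x:
        prev = sample + a * prev
        y.append(prev)
    return y

def cross_correlate_valid(x, k):
    """out[n] = sum_j x[n + j] * k[j], the operation conv1d-style layers compute."""
    length = len(k)
    return [
        sum(x[n + j] * k[j] for j in range(length))
        for n in range(len(x) - length + 1)
    ]

a = 0.5
ir_length = 32

# Done once at initialization: record the impulse response, then flip it.
impulse = [1.0] + [0.0] * (ir_length - 1)
ir = iir_first_order(impulse, a)   # h[n] = a ** n
ir_flipped = ir[::-1]

# Forward pass: left-pad for causality, then cross-correlate with the flipped IR.
x = [1.0, -2.0, 0.5, 3.0, 0.0, -1.0, 2.0, 0.25]
padded = [0.0] * (ir_length - 1) + x
y_fir = cross_correlate_valid(padded, ir_flipped)

# The sample-by-sample recursion gives the same result.
y_iir = iir_first_order(x, a)
max_diff = max(abs(u - v) for u, v in zip(y_fir, y_iir))
```

Here the signal is shorter than the stored impulse response, so the two outputs agree to floating-point precision. For longer signals the match is only as good as the truncated tail of the impulse response is small.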
Performance Characteristics:
Speedup: ~100x vs original DRNLFilterbank
Throughput: ~0.5x realtime on CPU (vs ~0.005x for original)
Accuracy: max absolute diff < 3e-4, mean diff < 1e-5
Memory: ~5 MB for IR storage (50 channels, ir_length=4096)
Batch scaling: Best with batch size ≤ 4, degrades at batch=8
Training Considerations:
Filter parameters (CF, BW, LP_cutoff) are frozen (not trainable) because
impulse responses are pre-computed. Only nonlinearity parameters (a, b, c)
remain trainable. This is a design trade-off for performance.
If you need to train filter parameters, use the original DRNLFilterbank.
The ~100x speedup is measured relative to the original IIR-based implementation.
However, the original is ~200x slower than realtime, so FastDRNLFilterbank is
still ~2x slower than realtime on typical CPUs. For real-time applications or
very large-scale training, consider:
Using batch size ≤ 4 (batch=8 has poor scaling)
Processing audio in chunks
Using GPU acceleration (if available)
Using a shorter ir_length (e.g., 2048) for speed vs accuracy trade-off
Training Limitations:
Because impulse responses are pre-computed at initialization, filter parameters
(center frequencies, bandwidths, lowpass cutoffs) cannot be trained via gradient
descent. Only the nonlinearity parameters (a, b, c) remain trainable.
If you need fully trainable filters, use DRNLFilterbank. If you only
need to optimize the nonlinearity while keeping filters fixed, FastDRNLFilterbank
is the better choice.
Numerical Accuracy:
The FFT-based convolution introduces small numerical differences compared to
the original IIR implementation:
Maximum absolute difference: ~3e-4
Mean absolute difference: ~1e-5
Relative error: < 0.01%
These differences are negligible for most applications and are well below
numerical precision limits of auditory models.
>>> import torch
>>> from torch_amt.common.filterbanks import FastDRNLFilterbank, DRNLFilterbank
>>>
>>> # Original (slow)
>>> drnl = DRNLFilterbank((250, 8000), fs=44100, n_channels=50)
>>>
>>> # Fast version (100x speedup)
>>> drnl_fast = FastDRNLFilterbank((250, 8000), fs=44100, n_channels=50)
>>>
>>> x = torch.randn(1, 22050)  # 0.5 s @ 44.1 kHz
>>> y = drnl_fast(x)  # [1, 50, 22050]
>>> print(f"Output shape: {y.shape}")
Output shape: torch.Size([1, 50, 22050])
With trainable nonlinearity parameters:
>>> drnl = FastDRNLFilterbank((250, 8000), fs=44100, n_channels=50, learnable=True)
>>> optimizer = torch.optim.Adam(drnl.parameters(), lr=1e-3)
>>>
>>> # Only a, b, c will be updated (3*50 = 150 parameters)
>>> print(f"Trainable parameters: {sum(p.numel() for p in drnl.parameters())}")
Trainable parameters: 150
>>>
>>> y = drnl(x)
>>> # Placeholder loss and target for illustration; substitute your own
>>> criterion = torch.nn.MSELoss()
>>> target = torch.zeros_like(y)
>>> loss = criterion(y, target)
>>> loss.backward()
>>> optimizer.step()
Adjusting IR length for speed/accuracy trade-off:
>>> # Shorter IR = faster but less accurate
>>> drnl_short = FastDRNLFilterbank((250, 8000), fs=44100, ir_length=2048)
>>>
>>> # Longer IR = slower but more accurate
>>> drnl_long = FastDRNLFilterbank((250, 8000), fs=44100, ir_length=8192)
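To give a feel for why ir_length trades accuracy for speed, the sketch below estimates how much energy of a single gammatone-like envelope is discarded when the impulse response is cut off. It is a rough illustration under stated assumptions, not a reproduction of the accuracy figures quoted above: it models only one 4th-order gammatone envelope at the lowest center frequency (250 Hz), uses the common bandwidth choice b = 1.019 * ERB(fc), and the helper names are hypothetical. The full DRNL chain has additional cascaded stages.

```python
import math

def gammatone_envelope(n_samples, fs, bandwidth_hz, order=4):
    """Envelope t**(order-1) * exp(-2*pi*b*t), sampled at rate fs."""
    return [
        (i / fs) ** (order - 1) * math.exp(-2.0 * math.pi * bandwidth_hz * i / fs)
        for i in range(n_samples)
    ]

def truncated_energy_fraction(ir_length, fs=44100, fc=250.0, reference_length=16384):
    """Fraction of envelope energy that lies beyond ir_length samples."""
    erb = 24.7 * (4.37 * fc / 1000.0 + 1.0)  # ERB formula used in this documentation
    env = gammatone_envelope(reference_length, fs, 1.019 * erb)  # 1.019 is an assumption
    total = sum(v * v for v in env)
    tail = sum(v * v for v in env[ir_length:])
    return tail / total

fractions = {n: truncated_energy_fraction(n) for n in (1024, 2048, 4096)}
for n, frac in fractions.items():
    print(f"ir_length={n}: discarded energy fraction = {frac:.3e}")
```

Low center frequencies have the narrowest bandwidths and therefore the longest impulse responses, so they are the channels most affected by a short ir_length.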
Sets up DRNL filterbank with specified center frequencies, computes
CF-dependent parameters (gains, bandwidths, nonlinearity coefficients),
and precomputes filter coefficients for efficient processing.
fc (torch.Tensor or tuple) – Center frequency specification.
If torch.Tensor: Explicit center frequencies in Hz, shape (n,).
User has full control over frequency spacing and count.
If tuple (flow, fhigh): Auto-generate n_channels ERB-spaced
frequencies between flow and fhigh Hz using erbspacebw().
fs (float) – Sampling rate in Hz. Typical values: 44100, 48000, or 32000.
Must be at least 2x the highest center frequency to avoid aliasing.
n_channels (int, optional) – Number of frequency channels. Default: 50.
Only used when fc is a tuple. Determines ERB spacing resolution:
more channels = finer frequency resolution, higher computational cost.
’NH’: Normal hearing with intact cochlear compression.
Nonlinearity parameters (a, b, c) computed from Paulick et al. (2024).
At CF <= 1000 Hz: frequency-dependent compression. Above 1000 Hz:
frozen at 1500 Hz parameters.
’HIx’: Hearing impaired without cochlear compression.
Linear nonlinearity (a=0, removes compression term). Models
outer hair cell dysfunction.
model (str, optional) – Model parametrization variant. Default: ‘paulick2024’.
’paulick2024’: Current CASP version (Paulick et al. 2024).
Linear path: 2 cascaded gammatone + 4 cascaded lowpass.
Nonlinear path: 2 cascaded gammatone + 1 lowpass.
’jepsen2008’: Previous version (Jepsen et al. 2008).
Linear path: 3 cascaded gammatone + 4 cascaded lowpass.
Nonlinear path: 3 cascaded gammatone + 3 cascaded lowpass.
learnable (bool, optional) – If True, makes CF-dependent parameters (CF_lin, BW_lin, g, CF_nlin,
BW_nlin, a, b, c) into nn.Parameter for gradient-based optimization.
Default: False (parameters are buffers, not optimized).
dtype (torch.dtype, optional) – Data type for computations and parameters. Default: torch.float64.
Double precision recommended for numerical stability in cascaded
IIR filtering (gammatone and Butterworth filters).
Raises:
ValueError – If model is not ‘paulick2024’ or ‘jepsen2008’.
When fc is a tuple, frequencies are spaced according to the
Equivalent Rectangular Bandwidth (ERB) scale:
\[\text{ERB}(f) = 24.7 (4.37 f / 1000 + 1)\]
This spacing matches the frequency resolution of the human auditory system,
with finer spacing at low frequencies and coarser spacing at high frequencies.
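The ERB formula above can be turned into a frequency spacing by integrating 1/ERB(f), which gives an ERB-rate scale. The sketch below is a minimal stdlib illustration; the function names are hypothetical, and it simply places n_channels points uniformly on the ERB-rate scale between the two endpoints. The library's erbspacebw() may position the points differently.

```python
import math

def erb_bandwidth(f_hz):
    """ERB(f) = 24.7 * (4.37 * f / 1000 + 1), in Hz."""
    return 24.7 * (4.37 * f_hz / 1000.0 + 1.0)

# Integrating 1 / ERB(f) over f gives the ERB-rate scale with this constant.
_SCALE = 1000.0 / (24.7 * 4.37)

def hz_to_erb_rate(f_hz):
    return _SCALE * math.log(4.37 * f_hz / 1000.0 + 1.0)

def erb_rate_to_hz(erb_rate):
    return (math.exp(erb_rate / _SCALE) - 1.0) * 1000.0 / 4.37

def erb_spaced(f_low, f_high, n_channels):
    """n_channels (>= 2) frequencies, uniformly spaced on the ERB-rate scale."""
    e_low, e_high = hz_to_erb_rate(f_low), hz_to_erb_rate(f_high)
    step = (e_high - e_low) / (n_channels - 1)
    return [erb_rate_to_hz(e_low + i * step) for i in range(n_channels)]

fc = erb_spaced(250.0, 8000.0, 50)
spacing = [hi - lo for lo, hi in zip(fc, fc[1:])]
print(f"first gap: {spacing[0]:.1f} Hz, last gap: {spacing[-1]:.1f} Hz")
```

The gap between neighboring center frequencies grows with frequency, which is the finer-at-low, coarser-at-high behavior described above.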
Parameter Initialization:
During initialization, the following steps occur:
Generate or validate center frequencies (self.fc)
Compute CF-dependent parameters (_compute_parameters):
- Linear path: CF_lin, BW_lin, LP_lin_cutoff, g
- Nonlinear path: CF_nlin, BW_nlin, LP_nlin_cutoff, a, b, c
Store coefficients as lists for scipy.signal.lfilter compatibility
Computational Cost:
Filter coefficient computation is O(n_channels), performed once at
initialization. Forward pass cost is O(n_channels * batch * time),
dominated by sequential filtering operations.
Model Variants:
The two model variants differ in filter orders, which affects frequency
selectivity and computational cost. Higher orders give sharper frequency
tuning but require more cascaded filtering.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
copied to that device
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
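The module layout from the example above can be rebuilt in a few lines to try get_submodule directly. The class names A, B and C are illustrative; only the attribute names net_b, net_c, linear and conv come from the description.

```python
from torch import nn

class C(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_c = C()
        self.linear = nn.Linear(1, 1)

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = B()

a = A()

# Fully-qualified names are attribute paths joined by dots.
linear = a.get_submodule("net_b.linear")
conv = a.get_submodule("net_b.net_c.conv")

# A path that does not resolve raises AttributeError.
try:
    a.get_submodule("net_b.missing")
    found_missing = True
except AttributeError:
    found_missing = False
```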
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
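The effect of strict and the returned key lists can be seen with a small sketch. It uses a plain nn.Linear and a hand-built, deliberately mismatched state dict; the key name "extra" is made up for the example.

```python
import torch
from torch import nn

model = nn.Linear(2, 2)

# "weight" is expected, "bias" is absent, "extra" is not known to the module.
partial_state = {
    "weight": torch.zeros(2, 2),
    "extra": torch.zeros(1),
}

# strict=False loads what matches and reports the rest.
result = model.load_state_dict(partial_state, strict=False)
print(result.missing_keys)     # keys the module expects but did not receive
print(result.unexpected_keys)  # keys supplied but not expected

# strict=True turns the same mismatch into a RuntimeError.
try:
    model.load_state_dict(partial_state, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True
```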
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm’s running_mean
is not a parameter, but is part of the module’s state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting persistent to False. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module’s
state_dict.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If None, then operations
that run on buffers, such as cuda, are ignored. If None,
the buffer is not included in the module’s state_dict.
persistent (bool): whether the buffer is part of this module’s
state_dict.
The hook will be called every time after forward() has computed an output.
If with_kwargs is False or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
output. It can modify the input inplace but it will not have effect on
forward since this is called after forward() is called. The hook
should have the following signature:
hook(module, args, output) -> None or modified output
If with_kwargs is True, the forward hook will be passed the
kwargs given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature:
hook(module, args, kwargs, output) -> None or modified output
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided hook will be fired
before all existing forward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this torch.nn.Module. Note that global
forward hooks registered with
register_module_forward_hook() will fire before all hooks
registered by this method.
Default: False
with_kwargs (bool): If True, the hook will be passed the
kwargs given to the forward function.
Default: False
always_call (bool): If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The hook will be called every time before forward() is invoked.
If with_kwargs is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature:
hook(module, args) -> None or modified input
If with_kwargs is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature:
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing forward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this torch.nn.Module. Note that global
forward_pre hooks registered with
register_module_forward_pre_hook() will fire before all
hooks registered by this method.
Default: False
with_kwargs (bool): If true, the hook will be passed the kwargs
given to the forward function.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_input and grad_output are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of grad_input in
subsequent computations. grad_input will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in grad_input and grad_output will be None for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_output is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of grad_output in
subsequent computations. Entries in grad_output will be None for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Register a post-hook to be run after module’s load_state_dict() is called.
It should have the following signature::
hook(module, incompatible_keys) -> None
The module argument is the current module that this hook is registered
on, and the incompatible_keys argument is a NamedTuple consisting
of attributes missing_keys and unexpected_keys. missing_keys
is a list of str containing the missing keys and
unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling load_state_dict() with
strict=True are affected by modifications the hook makes to
missing_keys or unexpected_keys, as expected. Additions to either
set of keys will result in an error being thrown when strict=True, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state
found within the state_dict. Implement this function and a corresponding
get_extra_state() for your module if you need to store extra state within its
state_dict.
Set the submodule given by target if it exists, otherwise throw an error.
Note
If strict is set to False (default), the method will replace an existing submodule
or create a new submodule if the parent module exists. If strict is set to True,
the method will only attempt to replace an existing submodule and throw an error if
the submodule does not exist.
For example, let’s say you have an nn.Module A that
looks like this:
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To override the Conv2d with a new submodule Linear, you
could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),
where strict could be True or False.
To add a new submodule Conv2d to the existing net_b module,
you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).
In the above, if you set strict=True and call
set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError
will be raised because net_b does not have a submodule named conv.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
strict: If False, the method will replace an existing submodule
or create a new submodule if the parent module exists. If True,
the method will only attempt to replace an existing submodule and throw an error
if the submodule doesn’t already exist.
Raises:
ValueError: If the target string is empty or if module is not an instance of nn.Module.
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references
to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for
destination, prefix and keep_vars in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Warning
Please avoid the use of argument destination as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional): by default the Tensors
returned in the state dict are detached from autograd. If it’s
set to True, detaching will not be performed.
Default: False.
Returns:
dict:
a dictionary containing a whole state of the module
Its signature is similar to torch.Tensor.to(), but only accepts
floating point or complex dtypes. In addition, this method will
only cast the floating point or complex parameters and buffers to dtype
(if given). The integral parameters and buffers will be moved to
device, if that is given, but with dtypes unchanged. When
non_blocking is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
Args:
device (torch.device): the desired device of the parameters
and buffers in this module
dtype (torch.dtype): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. Dropout, BatchNorm,
etc.
Args:
mode (bool): whether to set training mode (True) or evaluation
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on XPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Excitation pattern with asymmetric spreading for Glasberg & Moore (2002) loudness model.
Applies frequency-domain spreading to simulate the spread of excitation along
the basilar membrane. The spreading function models how energy at one frequency
activates adjacent auditory filters, with asymmetric and level-dependent characteristics:
Asymmetric: More spreading toward lower frequencies (shallower slope)
than higher frequencies (steeper slope)
Level-dependent: Slopes decrease with increasing level (more spreading
at high levels)
This is the second stage of the Glasberg & Moore (2002) loudness model, applied
after ERB integration (ERBIntegration) to compute the excitation pattern
from the excitation in individual ERB bands.
Parameters:
fs (int) – Sampling rate in Hz. Default: 32000. Used for consistency with
ERBIntegration but not directly used in spreading computation.
f_min (float) – Minimum center frequency in Hz. Default: 50.0. Defines lower bound
of ERB channel range.
f_max (float) – Maximum center frequency in Hz. Default: 15000.0. Defines upper bound
of ERB channel range.
erb_step (float) – ERB-rate spacing step. Default: 0.25. Determines ERB channel resolution.
Must match ERBIntegration for consistent processing.
learnable (bool) – If True, spreading slope parameters become learnable nn.Parameter
objects (4 parameters: upper/lower base slopes, upper/lower level
dependencies). Default: False (fixed slopes from Moore & Glasberg 1987).
The spreading slopes depend on excitation level following Moore & Glasberg (1987):
\[\begin{split}s_u(E) &= s_{u,\text{base}} + k_u (E - E_{\text{ref}}) \\
s_l(E) &= s_{l,\text{base}} + k_l (E - E_{\text{ref}})\end{split}\]
where \(E_{\text{ref}} = 60\) dB SPL is the reference level, \(s_{u,\text{base}} = 27\) dB/ERB
and \(s_{l,\text{base}} = 11\) dB/ERB are base slopes, and \(k_u = -0.37\), \(k_l = -0.20\)
are level dependencies (slopes decrease at high levels -> more spreading).
Slopes are clamped to minimum 10.0 dB/ERB to prevent unrealistic spreading at very high levels.
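The level dependence and clamping described above can be sketched in plain Python. This is a minimal illustration built only from the constants quoted in this section; the function name `spreading_slopes` is hypothetical and is not the library's `get_spreading_slopes` implementation.

```python
# Constants quoted in the text above (defaults attributed to Moore & Glasberg 1987).
LEVEL_REF_DB = 60.0      # E_ref, dB SPL
UPPER_BASE = 27.0        # s_u,base, dB/ERB
LOWER_BASE = 11.0        # s_l,base, dB/ERB
UPPER_LEVEL_DEP = -0.37  # k_u
LOWER_LEVEL_DEP = -0.20  # k_l
MIN_SLOPE = 10.0         # clamp floor, dB/ERB


def spreading_slopes(level_db: float) -> tuple[float, float]:
    """Return (upper, lower) slopes in dB/ERB for an excitation level in dB SPL."""
    delta = level_db - LEVEL_REF_DB
    # Linear level dependence, then clamp so slopes never fall below the floor.
    upper = max(UPPER_BASE + UPPER_LEVEL_DEP * delta, MIN_SLOPE)
    lower = max(LOWER_BASE + LOWER_LEVEL_DEP * delta, MIN_SLOPE)
    return upper, lower


for level in (40, 60, 80, 100):
    upper, lower = spreading_slopes(level)
    print(f"{level} dB SPL: upper={upper:.1f}, lower={lower:.1f}")
```

With these constants the lower slope reaches the 10 dB/ERB floor once the level exceeds 65 dB SPL, which is why the lower slope stays at 10.0 for the 80 and 100 dB cases.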
Asymmetry Rationale:
The asymmetric spreading (\(s_u > s_l\)) reflects physiological properties
of the cochlea:
Upward spreading (toward high frequencies): Steeper decay because
high-frequency channels are physically distant from excitation site
on basilar membrane.
Downward spreading (toward low frequencies): Shallower decay because
traveling wave on basilar membrane spreads more gradually toward apex
(low-frequency region).
At 60 dB SPL, asymmetry ratio is \(27/11 \approx 2.45:1\).
Computational Complexity:
The triple-nested loop (batch, frames, channels) with inner loop over all
channels gives complexity:
Time: O(batch x n_frames x n_erb_bands²)
For (2, 32, 150): ~1.44M operations
Optimization: Skip contributions with attenuation >50 dB (negligible)
Typical runtime: ~50 ms on CPU, ~5 ms on GPU (for 2 batches, 32 frames)
Spreading Limits:
Contributions with attenuation >50 dB are skipped (threshold check) to avoid:
Numerical instability in log-domain computations
Wasted computation on negligible contributions
At 27 dB/ERB upper slope: >50 dB attenuation at ~1.85 ERB distance
At 11 dB/ERB lower slope: >50 dB attenuation at ~4.55 ERB distance
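The two cutoff distances follow from dividing the 50 dB threshold by the slope. The sketch below reproduces that arithmetic, assuming attenuation grows linearly with ERB distance (slope times distance); the helper names are hypothetical and the channel count assumes the default `erb_step` of 0.25.

```python
CUTOFF_DB = 50.0  # attenuation threshold quoted above


def cutoff_distance_erb(slope_db_per_erb: float, cutoff_db: float = CUTOFF_DB) -> float:
    """ERB distance at which a linear-in-dB skirt reaches the cutoff attenuation."""
    return cutoff_db / slope_db_per_erb


def channels_within(distance_erb: float, erb_step: float = 0.25) -> int:
    """Whole number of channel steps that fit inside the given ERB distance."""
    return int(distance_erb / erb_step)


for slope in (27.0, 11.0):
    d = cutoff_distance_erb(slope)
    print(f"{slope} dB/ERB: cutoff at {d:.2f} ERB, {channels_within(d)} channels away")
```

At the reference level this means only a handful of neighbouring channels on the steep side, and fewer than twenty on the shallow side, contribute above the threshold.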
Log-Domain Summation:
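Uses torch.logaddexp to sum contributions given in dB without converting them to the
linear domain. The dB-domain power sum of two levels \(a\) and \(b\) is:
\[L_{\text{sum}}(a, b) = 10 \log_{10}\left(10^{a/10} + 10^{b/10}\right)\]
Note that torch.logaddexp(a,b) itself evaluates the natural-log form \(\log(e^a + e^b)\), so dB
values correspond to its arguments scaled by \(\ln(10)/10\) and its result scaled by \(10/\ln(10)\).
This prevents numerical overflow/underflow when summing many contributions.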
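A scalar sketch of this log-domain summation is shown below. It uses only the standard library; `logaddexp` here is a hand-written stand-in for the natural-log behaviour of `torch.logaddexp`, and `db_power_sum` is a hypothetical helper, not a function from this package.

```python
import math

LN10_OVER_10 = math.log(10.0) / 10.0  # maps a dB value to the natural log of its power


def logaddexp(a: float, b: float) -> float:
    """Stable log(exp(a) + exp(b)) for finite inputs or -inf."""
    if a == -math.inf and b == -math.inf:
        return -math.inf
    hi, lo = (a, b) if a >= b else (b, a)
    # Factor out the larger term so the exponential never overflows.
    return hi + math.log1p(math.exp(lo - hi))


def db_power_sum(a_db: float, b_db: float) -> float:
    """Level in dB of the sum of two powers that are each given in dB."""
    return logaddexp(a_db * LN10_OVER_10, b_db * LN10_OVER_10) / LN10_OVER_10


print(round(db_power_sum(60.0, 60.0), 2))  # two equal levels add about 3.01 dB
```

Because only the difference between the two levels is exponentiated, a very weak contribution simply leaves the stronger level unchanged instead of underflowing.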
Device Support:
Supports CPU, CUDA, and MPS. All buffers (fc_erb, erb_centers,
level_ref) and parameters are moved automatically with .to(device).
Apply excitation pattern spreading to ERB-band excitation.
Parameters:
excitation (Tensor) – Excitation in dB SPL, shape (batch, n_frames, n_erb_bands).
Typically output from ERBIntegration.
Returns:
Spread excitation in dB SPL, shape (batch, n_frames, n_erb_bands).
Represents excitation pattern after accounting for frequency
spreading along the basilar membrane.
The upper slope is always steeper than the lower slope, reflecting
the physiological asymmetry of basilar membrane vibration patterns.
Level Dependency:
Both slopes decrease with increasing level (negative coefficients),
meaning more spreading at high levels. This matches auditory filter
tuning measurements showing broader filters at high SPLs.
Examples
>>> from torch_amt.common.filterbanks import ExcitationPattern
>>> exc_pattern = ExcitationPattern()
>>>
>>> # Compare slopes at different levels
>>> for level in [40, 60, 80, 100]:
...     upper, lower = exc_pattern.get_spreading_slopes(level)
...     print(f"{level} dB SPL: upper={upper:.1f}, lower={lower:.1f}, ratio={upper/lower:.2f}:1")
40 dB SPL: upper=34.4, lower=15.0, ratio=2.29:1
60 dB SPL: upper=27.0, lower=11.0, ratio=2.45:1
80 dB SPL: upper=19.6, lower=10.0, ratio=1.96:1
100 dB SPL: upper=12.2, lower=10.0, ratio=1.22:1
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm’s running_mean
is not a parameter, but is part of the module’s state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting persistent to False. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module’s
state_dict.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If None, then operations
that run on buffers, such as cuda, are ignored. If None,
the buffer is not included in the module’s state_dict.
persistent (bool): whether the buffer is part of this module’s
The hook will be called every time after forward() has computed an output.
If with_kwargs is False or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
output. It can modify the input inplace but it will not have effect on
forward since this is called after forward() is called. The hook
should have the following signature:
hook(module, args, output) -> None or modified output
If with_kwargs is True, the forward hook will be passed the
kwargs given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature:
hook(module, args, kwargs, output) -> None or modified output
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided hook will be fired
before all existing forward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this torch.nn.Module. Note that global
forward hooks registered with
register_module_forward_hook() will fire before all hooks
registered by this method.
Default: False
with_kwargs (bool): If True, the hook will be passed the
kwargs given to the forward function.
Default: False
always_call (bool): If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The hook will be called every time before forward() is invoked.
If with_kwargs is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature:
hook(module, args) -> None or modified input
If with_kwargs is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature:
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing forward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this torch.nn.Module. Note that global
forward_pre hooks registered with
register_module_forward_pre_hook() will fire before all
hooks registered by this method.
Default: False
with_kwargs (bool): If true, the hook will be passed the kwargs
given to the forward function.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_input and grad_output are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of grad_input in
subsequent computations. grad_input will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in grad_input and grad_output will be None for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The grad_output is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of grad_output in
subsequent computations. Entries in grad_output will be None for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Register a post-hook to be run after module’s load_state_dict() is called.
It should have the following signature:
hook(module, incompatible_keys) -> None
The module argument is the current module that this hook is registered
on, and the incompatible_keys argument is a NamedTuple consisting
of attributes missing_keys and unexpected_keys. missing_keys
is a list of str containing the missing keys and
unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling load_state_dict() with
strict=True are affected by modifications the hook makes to
missing_keys or unexpected_keys, as expected. Additions to either
set of keys will result in an error being thrown when strict=True, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state
found within the state_dict. Implement this function and a corresponding
get_extra_state() for your module if you need to store extra state within its
state_dict.
Set the submodule given by target if it exists, otherwise throw an error.
Note
If strict is set to False (default), the method will replace an existing submodule
or create a new submodule if the parent module exists. If strict is set to True,
the method will only attempt to replace an existing submodule and throw an error if
the submodule does not exist.
For example, let’s say you have an nn.Module A that
looks like this:
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To override the Conv2d with a new submodule Linear, you
could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),
where strict can be either True or False.
To add a new submodule Conv2d to the existing net_b module,
you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).
In the above, if you set strict=True and call
set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError
will be raised because net_b does not have a submodule named conv.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
strict: If False, the method will replace an existing submodule
or create a new submodule if the parent module exists. If True,
the method will only attempt to replace an existing submodule and throw an error
if the submodule doesn’t already exist.
Raises:
ValueError: If the target string is empty or if module is not an instance of nn.Module.
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references
to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for
destination, prefix and keep_vars in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Warning
Please avoid the use of argument destination as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional): by default the Tensors
returned in the state dict are detached from autograd. If it’s
set to True, detaching will not be performed.
Default: False.
Returns:
dict:
a dictionary containing a whole state of the module
Its signature is similar to torch.Tensor.to(), but only accepts
floating point or complex dtypes. In addition, this method will
only cast the floating point or complex parameters and buffers to dtype
(if given). The integral parameters and buffers will be moved to
device, if that is given, but with dtypes unchanged. When
non_blocking is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
Args:
device (torch.device): the desired device of the parameters
and buffers in this module
dtype (torch.dtype): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. Dropout, BatchNorm,
etc.
Args:
mode (bool): whether to set training mode (True) or evaluation
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on XPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Excitation pattern computation using roex (rounded-exponential) auditory filters.
Converts a sparse spectrum (discrete frequency components with levels) into
an excitation pattern on the ERB-rate scale using rounded-exponential (roex)
filters with level-dependent lower slopes, following Moore et al. (2016).
The roex filter models the auditory filter shape with asymmetric slopes that
depend on the input sound level, capturing the nonlinear frequency selectivity
of the human cochlea. Each spectral component spreads to nearby ERB channels
according to the roex weighting function.
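As a rough illustration of the weighting, the sketch below evaluates the textbook symmetric roex(p) shape \(W(g) = (1 + pg)e^{-pg}\). It assumes the Glasberg & Moore (1990) ERB formula and the relation \(p = 4 f_c / \text{ERB}(f_c)\), and it deliberately omits the level-dependent lower slope described in this section. None of these choices is confirmed by this page, and all names are hypothetical.

```python
import math


def erb_hz(fc_hz: float) -> float:
    """Equivalent rectangular bandwidth in Hz (Glasberg & Moore 1990, assumed)."""
    return 24.7 * (4.37 * fc_hz / 1000.0 + 1.0)


def roex_weight(f_hz: float, fc_hz: float) -> float:
    """Power weight of a component at f_hz seen by a roex(p) filter centred at fc_hz."""
    p = 4.0 * fc_hz / erb_hz(fc_hz)  # slope parameter; level dependence omitted
    g = abs(f_hz - fc_hz) / fc_hz    # normalised deviation from the centre frequency
    return (1.0 + p * g) * math.exp(-p * g)


for f in (1000.0, 1100.0, 1200.0, 1500.0):
    print(f"{f:.0f} Hz -> weight {roex_weight(f, 1000.0):.4f}")
```

The weight is 1 at the centre frequency and falls off monotonically on both sides; an asymmetric, level-dependent filter would use a different `p` below and above the centre.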
erb_step (float) – ERB-rate step size between channels. Default: 0.25 (creates 150 channels
from 1.75 to 39.0).
spreading_limit_octaves (float) – Maximum spreading distance in octaves (±). Default: 4.0.
Components more than this distance from a channel center frequency
are ignored to reduce computational cost.
learnable (bool) – If True, makes level_dep_factor (0.35) a learnable parameter.
Default: False.
dtype (dtype) – Data type for computations. Default: torch.float32.
Level-dependent slope factor (default 0.35). Registered as buffer
or parameter depending on learnable. When learnable, clamped
to [0.0, 1.0] for numerical stability.
At high levels (X > 51 dB): \(p_l < p\) (broader lower slope)
The reference level of 51 dB SPL and slope ratio \(p(f)/p(1000)\)
ensure consistent scaling across frequencies.
ERB Scale Computation:
Channels are spaced uniformly on the ERB-rate scale from erb_lower to
erb_upper with step erb_step. The ERB-rate is converted to frequency
using erbrate2f():
This spacing matches the critical band scale of human hearing, with finer
resolution at low frequencies.
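The channel layout can be sketched as follows. The conversion functions assume the Glasberg & Moore (1990) ERB-rate formula, \(E = 21.4 \log_{10}(4.37 f / 1000 + 1)\); the library's own `erbrate2f()` may differ in detail, and `erbrate_to_hz`, `hz_to_erbrate`, and `erb_channels` are hypothetical names used only for this illustration.

```python
import math


def hz_to_erbrate(f_hz: float) -> float:
    """ERB-rate (Cams) for a frequency in Hz (Glasberg & Moore 1990, assumed)."""
    return 21.4 * math.log10(4.37 * f_hz / 1000.0 + 1.0)


def erbrate_to_hz(erb_rate: float) -> float:
    """Inverse of hz_to_erbrate."""
    return (10.0 ** (erb_rate / 21.4) - 1.0) * 1000.0 / 4.37


def erb_channels(erb_lower: float = 1.75, erb_upper: float = 39.0,
                 erb_step: float = 0.25) -> list[float]:
    """Uniformly spaced ERB-rate values, both end points included."""
    n = int(round((erb_upper - erb_lower) / erb_step)) + 1
    return [erb_lower + i * erb_step for i in range(n)]


centers_erb = erb_channels()
centers_hz = [erbrate_to_hz(e) for e in centers_erb]
print(len(centers_hz))  # 150 channels for the default range and step
```

Uniform steps in ERB-rate translate into narrow spacing in Hz at the low end and wide spacing at the high end, which is the "finer resolution at low frequencies" noted above.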
Spreading and Computational Efficiency:
To reduce computation, spreading is limited to ±spreading_limit_octaves
around each component. For the default 4.0 octaves, a component at
frequency \(f\) contributes only to channels whose center frequencies lie
between \(f / 16\) and \(16 f\).
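The octave limit amounts to a simple range test on the frequency ratio. The following sketch shows that test in isolation; the function names are hypothetical and the real implementation presumably applies the same idea as a vectorized mask.

```python
import math


def spreading_range_hz(f_hz: float, limit_octaves: float = 4.0) -> tuple[float, float]:
    """Lowest and highest channel frequency a component at f_hz may contribute to."""
    factor = 2.0 ** limit_octaves
    return f_hz / factor, f_hz * factor


def within_limit(f_component_hz: float, fc_channel_hz: float,
                 limit_octaves: float = 4.0) -> bool:
    """True if the channel lies within the octave limit of the component."""
    return abs(math.log2(fc_channel_hz / f_component_hz)) <= limit_octaves


print(spreading_range_hz(1000.0))    # (62.5, 16000.0)
print(within_limit(1000.0, 100.0))   # True: about 3.3 octaves below
print(within_limit(1000.0, 50.0))    # False: about 4.3 octaves below
```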
Output Units:
The excitation pattern is returned in dB, computed as:
where \(L_i\) is the level of component \(i\) and \(W_i[n]\)
is the roex weighting from component \(i\) to channel \(n\).
Linear power summation is performed before converting to dB.
Maximum spreading distance in octaves (±). Default: 4.0.
Limits excitation spreading to reduce computation. A component at
frequency \(f\) only affects channels in the range
\([f/2^4, f \cdot 2^4] = [f/16, 16f]\). Smaller values
reduce accuracy but increase speed. Typical range: [2.0, 6.0].
learnable (bool) – If True, makes level_dep_factor (0.35) a learnable
nn.Parameter for gradient-based optimization. Default: False
(fixed buffer).
dtype (dtype) – Data type for computations and parameters. Default: torch.float32.
Use torch.float64 for higher precision if needed.
Raises:
ValueError – If erb_lower >= erb_upper (invalid range).
This value normalizes the level-dependent slope adjustment across frequencies.
Computational Cost:
Initialization is O(n_channels), typically ~150 channels requiring
~500 operations (ERB conversion + slope computation). This is negligible
compared to forward pass cost.
Compute excitation pattern from sparse spectrum (VECTORIZED).
Takes frequency components with their levels and spreads them to
ERB-spaced channels using roex filters with level-dependent slopes.
Optimization Note: This implementation is fully vectorized, processing
all (batch x components) in parallel instead of nested loops. This provides
~12x speedup compared to the original implementation while maintaining
numerical accuracy within 0.001 dB (imperceptible difference).
For each spectral component \(i\) at frequency \(f_i\) with level
\(L_i\), the contribution to channel \(n\) with center frequency
\(f_c[n]\) is computed as:
where \(\epsilon = 10^{-12}\) prevents log(0) errors. Channels with
no contributions have \(E_{\text{linear}} \approx 0\), resulting in
\(E_{\text{dB}} \approx -120\) dB (effectively silent).
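The power summation and the role of \(\epsilon\) can be sketched for a single channel as below. The sketch assumes each weight is a linear power factor applied to \(10^{L_i/10}\); the helper name `excitation_db` is hypothetical and this is not the library's vectorized code.

```python
import math

EPS = 1e-12  # floor that prevents log10(0), as described above


def excitation_db(levels_db: list[float], weights: list[float]) -> float:
    """Excitation in dB for one channel from component levels and roex weights."""
    # Sum weighted powers in the linear domain, then convert back to dB.
    linear = sum(w * 10.0 ** (lvl / 10.0) for lvl, w in zip(levels_db, weights))
    return 10.0 * math.log10(linear + EPS)


print(round(excitation_db([], []), 1))                      # silent channel: -120.0
print(round(excitation_db([60.0, 60.0], [1.0, 1.0]), 2))    # two equal components: 63.01
```

A channel that receives no contributions therefore lands at the \(-120\) dB floor rather than producing an infinite or undefined value.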
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm’s running_mean
is not a parameter, but is part of the module’s state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting persistent to False. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module’s
state_dict.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If None, then operations
that run on buffers, such as cuda, are ignored. If None,
the buffer is not included in the module’s state_dict.
persistent (bool): whether the buffer is part of this module’s
The hook will be called every time after forward() has computed an output.
If with_kwargs is False or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
output. It can modify the input inplace but it will not have effect on
forward since this is called after forward() is called. The hook
should have the following signature:
hook(module, args, output) -> None or modified output
If with_kwargs is True, the forward hook will be passed the
kwargs given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature:
hook(module, args, kwargs, output) -> None or modified output
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided hook will be fired
before all existing forward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this torch.nn.Module. Note that global
forward hooks registered with
register_module_forward_hook() will fire before all hooks
registered by this method.
Default: False
with_kwargs (bool): If True, the hook will be passed the
kwargs given to the forward function.
Default: False
always_call (bool): If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
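As an illustration of the forward-hook contract described above, here is a minimal sketch that records input and output shapes on a plain nn.Linear layer. The layer, the hook name save_output_shape, and the captured dictionary are illustrative choices, not part of this library.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
captured = {}

def save_output_shape(module, args, output):
    # args is the tuple of positional inputs; output is what forward() returned.
    captured["in_shape"] = tuple(args[0].shape)
    captured["out_shape"] = tuple(output.shape)
    # Returning None leaves the output unchanged.

handle = layer.register_forward_hook(save_output_shape)
_ = layer(torch.zeros(3, 4))
handle.remove()  # later forward calls no longer trigger the hook
```

Keeping the returned handle around is what makes the hook removable; without it the hook stays attached for the lifetime of the module.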
The hook will be called every time before forward() is invoked.
If with_kwargs is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature:
hook(module, args) -> None or modified input
If with_kwargs is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature:
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing forward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this torch.nn.Module. Note that global
forward_pre hooks registered with
register_module_forward_pre_hook() will fire before all
hooks registered by this method.
Default: False
with_kwargs (bool): If true, the hook will be passed the kwargs
given to the forward function.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
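A forward pre-hook can rewrite the input before forward() sees it. The sketch below, using an illustrative nn.Linear layer and a hypothetical hook named zero_input, returns a single tensor, which is wrapped into a tuple as described above.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)

def zero_input(module, args):
    # Returning a single tensor is allowed; it is wrapped into a tuple.
    return args[0] * 0.0

handle = layer.register_forward_pre_hook(zero_input)
y = layer(torch.ones(3, 4))  # forward() receives an all-zero input
handle.remove()
```

With a zeroed input, the output of the linear layer reduces to its bias, which makes the effect of the hook easy to check.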
The grad_input and grad_output are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of grad_input in
subsequent computations. grad_input will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in grad_input and grad_output will be None for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
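The following sketch shows what a full backward hook receives. It only inspects the gradients and returns None, so grad_input is left untouched. The layer and the hook name record_grads are illustrative.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
seen = {}

def record_grads(module, grad_input, grad_output):
    # Both arguments are tuples; entries may be None for non-Tensor arguments.
    seen["grad_input"] = [None if g is None else tuple(g.shape) for g in grad_input]
    seen["grad_output"] = [tuple(g.shape) for g in grad_output]

handle = layer.register_full_backward_hook(record_grads)
x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
handle.remove()
```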
The grad_output is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of grad_output in
subsequent computations. Entries in grad_output will be None for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Register a post-hook to be run after module’s load_state_dict() is called.
It should have the following signature::
hook(module, incompatible_keys) -> None
The module argument is the current module that this hook is registered
on, and the incompatible_keys argument is a NamedTuple consisting
of attributes missing_keys and unexpected_keys. missing_keys
is a list of str containing the missing keys and
unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling load_state_dict() with
strict=True are affected by modifications the hook makes to
missing_keys or unexpected_keys, as expected. Additions to either
set of keys will result in an error being thrown when strict=True, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
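To make the interaction with strict=True concrete, here is a small sketch in which a post-hook removes one missing key so that a partial state dict loads without error. The module, the hook name ignore_missing_bias, and the partial dictionary are illustrative.

```python
from torch import nn

model = nn.Linear(2, 2)

def ignore_missing_bias(module, incompatible_keys):
    # Modify the lists in place; the strict check runs after all post-hooks.
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")

handle = model.register_load_state_dict_post_hook(ignore_missing_bias)
partial = {"weight": model.weight.detach().clone()}  # no "bias" entry
result = model.load_state_dict(partial, strict=True)  # does not raise
handle.remove()
```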
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state
found within the state_dict. Implement this function and a corresponding
get_extra_state() for your module if you need to store extra state within its
state_dict.
Set the submodule given by target if it exists, otherwise throw an error.
Note
If strict is set to False (default), the method will replace an existing submodule
or create a new submodule if the parent module exists. If strict is set to True,
the method will only attempt to replace an existing submodule and throw an error if
the submodule does not exist.
For example, let’s say you have an nn.Module A that
looks like this:
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To override the Conv2d with a new submodule Linear, you
could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),
where strict could be True or False.
To add a new submodule Conv2d to the existing net_b module,
you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).
In the above, if you set strict=True and call
set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError
will be raised because net_b does not have a submodule named conv.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
strict: If False, the method will replace an existing submodule
or create a new submodule if the parent module exists. If True,
the method will only attempt to replace an existing submodule and throw an error
if the submodule doesn’t already exist.
Raises:
ValueError: If the target string is empty or if module is not an instance of nn.Module.
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references
to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for
destination, prefix and keep_vars in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Warning
Please avoid the use of argument destination as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional): by default the Tensors
returned in the state dict are detached from autograd. If it’s
set to True, detaching will not be performed.
Default: False.
Returns:
dict:
a dictionary containing a whole state of the module
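A short sketch of what state_dict() returns for a module that has both parameters and persistent buffers. nn.BatchNorm1d is used only as a convenient illustrative module.

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(3)

sd = bn.state_dict()                         # parameters and persistent buffers
sd_prefixed = bn.state_dict(prefix="norm.")  # same entries, prefixed keys

# Entries are detached by default but still share storage with the module.
shares_storage = sd["running_mean"].data_ptr() == bn.running_mean.data_ptr()
```

Because the entries reference the module's own tensors, clone them first if you need a snapshot that survives further training.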
Its signature is similar to torch.Tensor.to(), but only accepts
floating point or complex dtypes. In addition, this method will
only cast the floating point or complex parameters and buffers to dtype
(if given). The integral parameters and buffers will be moved to
device, if that is given, but with dtypes unchanged. When
non_blocking is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
Args:
device (torch.device): the desired device of the parameters
and buffers in this module
dtype (torch.dtype): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
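The casting rule above (floating-point tensors are cast, integral ones keep their dtype) can be checked with a small sketch. nn.BatchNorm1d is an illustrative choice because it holds an integer buffer, num_batches_tracked.

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(3)
returned = bn.to(torch.float64)  # modifies bn in place and returns it

float_dtype = bn.weight.dtype              # cast to float64
int_dtype = bn.num_batches_tracked.dtype   # integral buffer keeps its dtype
```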
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. Dropout, BatchNorm,
etc.
Args:
mode (bool): whether to set training mode (True) or evaluation
mode (False). Default: True.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on XPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
ERB-scale integration for Glasberg & Moore (2002) loudness model.
Integrates power spectral density (PSD) within Equivalent Rectangular Bandwidth
(ERB) spaced frequency bands to compute the excitation pattern. This transforms
the FFT-based frequency representation into an auditory-motivated frequency scale
that better represents human frequency discrimination.
The module computes 150 ERB-spaced channels from 50 Hz to 15000 Hz with 0.25
ERB-rate spacing, following Glasberg & Moore (2002). Each ERB band integrates
PSD energy using rectangular filters centered at the ERB channel frequencies.
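The channel layout described above can be sketched with the standard Glasberg & Moore ERB-rate formula. This is a hedged illustration: the function names are hypothetical, and the module's exact endpoint handling (which determines whether the count is exactly 150) is an assumption not confirmed by this documentation.

```python
import math

def hz_to_erb_rate(f_hz):
    # ERB-rate (ERB number) scale of Glasberg & Moore (1990).
    return 21.4 * math.log10(4.37 * f_hz / 1000.0 + 1.0)

def erb_rate_to_hz(e):
    # Inverse of hz_to_erb_rate.
    return (10.0 ** (e / 21.4) - 1.0) * 1000.0 / 4.37

def erb_bandwidth(f_hz):
    # Equivalent rectangular bandwidth in Hz at center frequency f_hz.
    return 24.7 * (4.37 * f_hz / 1000.0 + 1.0)

def erb_center_frequencies(f_min=50.0, f_max=15000.0, erb_step=0.25):
    # Uniform steps on the ERB-rate scale, starting at f_min and not exceeding f_max.
    e_min = hz_to_erb_rate(f_min)
    e_max = hz_to_erb_rate(f_max)
    n = int(math.floor((e_max - e_min) / erb_step)) + 1
    return [erb_rate_to_hz(e_min + i * erb_step) for i in range(n)]

fc = erb_center_frequencies()  # roughly 150 center frequencies with the defaults
```

With these defaults the ERB-rate range is roughly 1.8 to 39, so a 0.25 step yields a channel count close to the 150 quoted above.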
Parameters:
fs (int) – Sampling rate in Hz. Default: 32000. Determines Nyquist frequency for PSD
integration.
f_min (float) – Minimum center frequency in Hz. Default: 50.0. Lower bound of ERB scale
matching typical auditory models.
f_max (float) – Maximum center frequency in Hz. Default: 15000.0. Upper bound avoids
aliasing artifacts and matches loudness model requirements.
erb_step (float) – ERB-rate spacing step. Default: 0.25. Determines frequency resolution:
smaller values give finer resolution but more channels. 0.25 yields
approximately 150 channels matching Glasberg & Moore (2002).
learnable (bool) – If True, integration weights become learnable nn.Parameter objects
(one weight per ERB band). Default: False (unit weights).
Global bandwidth scale factor (scalar). Multiplies all ERB bandwidths.
When learnable=False, fixed to 1.0. When learnable=True, optimizable.
Clamped to [0.1, 10.0] during forward pass for numerical stability.
Allows optimization of effective filter bandwidth for better modeling.
Integrate the PSD over each band: \(P_i = w_i \sum_{f_k \in \text{band}_i} \text{PSD}(f_k) \, \Delta f\)
Convert to dB SPL: \(E_i = 10 \log_{10}(P_i / p_{\text{ref}}^2)\)
where \(\Delta f\) is PSD bin width, \(w_i\) is integration weight,
and \(p_{\text{ref}} = 20 \mu\text{Pa}\) is the standard acoustic reference pressure.
dB SPL Conversion:
The output excitation is in dB SPL assuming:
Input PSD is calibrated in Pa²/Hz
Reference pressure \(p_{\text{ref}} = 20 \times 10^{-6}\) Pa (20 μPa)
Integrate the PSD over each band: \(P_i = w_i \sum_{f_k \in \text{band}_i} \text{PSD}(f_k) \, \Delta f\)
Convert to dB SPL: \(E_i = 10 \log_{10}(P_i / p_{\text{ref}}^2)\)
where \(\Delta f\) is PSD bin width (freqs[1] - freqs[0]), \(w_i\)
is integration weight (learnable or 1.0), and \(p_{\text{ref}} = 20 \mu\text{Pa}\).
Rectangular Filters:
This implementation uses simple rectangular frequency masking. More
sophisticated models (e.g., roex filters) provide better approximation
of auditory filter shapes but are computationally more expensive.
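The rectangular integration and dB SPL conversion can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the module's code: the helper names, the inclusive band edges, and the eps floor are hypothetical.

```python
import math

P_REF = 20e-6  # reference pressure in Pa (20 micropascals)

def integrate_band(psd, freqs, f_lo, f_hi, weight=1.0):
    # Rectangular mask: sum PSD bins inside [f_lo, f_hi] and scale by bin width.
    df = freqs[1] - freqs[0]
    power = sum(p for p, f in zip(psd, freqs) if f_lo <= f <= f_hi) * df
    return weight * power

def to_db_spl(power, eps=1e-30):
    # eps guards against log10(0) for empty or silent bands (an assumption).
    return 10.0 * math.log10(max(power, eps) / P_REF ** 2)

freqs = [i * 10.0 for i in range(101)]   # 0 to 1000 Hz in 10 Hz bins
psd = [P_REF ** 2] * len(freqs)          # flat PSD of p_ref^2 per Hz
band_power = integrate_band(psd, freqs, 450.0, 550.0)
level_db = to_db_spl(band_power)
```

In this toy case 11 bins of 10 Hz fall inside the band, so the level is 10 log10(110), about 20.4 dB above the reference.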
dB SPL Calibration:
Output is in dB SPL (Sound Pressure Level) assuming:
Input PSD is calibrated in Pa²/Hz
Reference pressure: 20 μPa (standard for airborne sound)
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
copied to that device
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
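A minimal sketch of the get_extra_state() / set_extra_state() pair. The Calibrated module and its ref_db attribute are hypothetical and serve only to show how a non-tensor value travels through state_dict().

```python
import torch
from torch import nn

class Calibrated(nn.Module):
    def __init__(self, ref_db=100.0):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(1))
        self.ref_db = ref_db  # plain Python attribute, not a tensor

    def get_extra_state(self):
        # Called while building state_dict(); the result must be picklable.
        return {"ref_db": self.ref_db}

    def set_extra_state(self, state):
        # Called from load_state_dict() with the stored extra state.
        self.ref_db = state["ref_db"]

src = Calibrated(ref_db=94.0)
dst = Calibrated()
dst.load_state_dict(src.state_dict())
```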
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
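The nesting described in the example can be reproduced with three small illustrative classes (A, B, and C are stand-ins for the diagram, not library classes):

```python
from torch import nn

class C(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_c = C()
        self.linear = nn.Linear(1, 1)

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = B()

a = A()
conv = a.get_submodule("net_b.net_c.conv")

# A non-existent path raises AttributeError, which doubles as an existence check.
try:
    a.get_submodule("net_b.missing")
    found = True
except AttributeError:
    found = False
```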
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
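A brief sketch of the return value with strict=False, using an illustrative nn.Linear module and a deliberately mismatched dictionary:

```python
import torch
from torch import nn

model = nn.Linear(2, 2)
state = {
    "weight": torch.zeros(2, 2),  # matches an existing parameter
    "extra": torch.zeros(1),      # not expected by the module
}
result = model.load_state_dict(state, strict=False)
# result.missing_keys lists keys the module expected but did not receive;
# result.unexpected_keys lists keys the module does not know about.
```

With strict=True the same call would raise a RuntimeError listing both problems.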
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
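A small sketch of the deduplication behaviour, in which the same nn.Linear instance is registered twice inside an nn.Sequential:

```python
from torch import nn

l = nn.Linear(2, 2)
net = nn.Sequential(l, l)  # the same module instance appears twice

# The root module has the empty name; l is yielded once, under its first name.
names = [name for name, _ in net.named_modules()]
names_with_duplicates = [
    name for name, _ in net.named_modules(remove_duplicate=False)
]
```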
Multi-resolution FFT analysis for Glasberg & Moore (2002) loudness model.
Implements frequency-dependent time-frequency resolution by computing multiple STFTs
with different window lengths and selecting the appropriate window for each frequency
range. This achieves:
Long windows (better frequency resolution) for low frequencies
Short windows (better temporal resolution) for high frequencies
The window selection follows the original MATLAB implementation in the AMT
(Auditory Modeling Toolbox), which uses 6 Hann windows ranging from 2 to 64 ms.
Parameters:
fs (int) – Sampling rate in Hz. Default: 32000 (required by Glasberg & Moore 2002).
Note: The model is designed for 32 kHz sampling rate. Other rates may
produce incorrect results due to hardcoded frequency thresholds.
window_lengths (list | None) – FFT window lengths in samples. Default: [2048, 1024, 512, 256, 128, 64]
corresponding to [64, 32, 16, 8, 4, 2] ms at 32 kHz sampling rate.
hop_fraction (float) – Hop size as fraction of window length. Default: 0.5 (50% overlap).
Each window uses its own hop length = window_length * hop_fraction.
learnable (bool) – If True, frequency selection thresholds become learnable nn.Parameter objects.
Default: False (fixed thresholds from Glasberg & Moore 2002).
where \(w\) is the window function and \(f_s\) is the sampling rate.
This gives units of Pa²/Hz when the input is in Pascals.
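The PSD formula itself did not survive in this text. As a hedged illustration, the sketch below uses the standard periodogram normalization, dividing the squared magnitude spectrum by \(f_s \sum_n w[n]^2\). Whether the module also applies a one-sided factor of 2 is not stated here, so this two-sided version is an assumption, as are the helper names.

```python
import cmath
import math

def hann(n):
    # Periodic Hann window of length n.
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * i / n) for i in range(n)]

def psd_frame(x, w, fs):
    # Two-sided PSD of one frame: |DFT(x * w)|^2 / (fs * sum(w^2)).
    n = len(x)
    xw = [xi * wi for xi, wi in zip(x, w)]
    win_power = sum(wi * wi for wi in w)
    psd = []
    for k in range(n):
        X = sum(xw[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
        psd.append(abs(X) ** 2 / (fs * win_power))
    return psd

fs = 32000
n = 64
x = [math.sin(2.0 * math.pi * 1000.0 * t / fs) for t in range(n)]

# With a rectangular window, integrating the PSD over frequency
# (bin width fs / n) recovers the mean-square value of the frame.
psd_rect = psd_frame(x, [1.0] * n, fs)
total_power = sum(psd_rect) * fs / n
mean_square = sum(v * v for v in x) / n

psd_hann = psd_frame(x, hann(n), fs)  # same normalization with a Hann window
```

The direct DFT is used only to keep the sketch dependency-free; a real implementation would use an FFT.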
Learnable Parameters:
When learnable=True, the 5 frequency thresholds (10 real values total for
2 boundaries per threshold) become learnable. This allows the model to adapt
the window selection strategy during training.
Computational Complexity:
FFT operations: O(N W log W), where N is the number of windows and W is the window length
Interpolation: O(B T F), where B is batch size, T is the number of time frames, and F is the number of frequency bins
Total: dominated by FFT for typical parameters
For 1 second at 32 kHz with 6 windows: ~70M operations.
MATLAB Verification:
This implementation has been verified against the MATLAB reference:
glasberg2002.m: Main loudness model (lines 107-159)
Frequency boundaries: [80, 500, 1250, 2540, 4050] Hz from vLimitingIndices (line 25)
FFT length: 2048 samples (line 24)
Hop size: 1 ms = 32 samples (timeStep, line 27)
Note: The Python implementation uses slightly different thresholds
[500, 1000, 2000, 4000, 8000] Hz to better match the published paper,
while MATLAB uses [80, 500, 1250, 2540, 4050] Hz for backward compatibility.
Both produce perceptually similar results.
If True, the 5 frequency thresholds [500, 1000, 2000, 4000, 8000] Hz
become trainable nn.Parameter objects, allowing gradient-based
optimization of the window selection strategy.
Default: False (fixed thresholds from Glasberg & Moore 2002).
Raises:
ValueError – If hop_fraction is not in range (0, 1].
ValueError – If window_lengths contains non-positive integers.
Notes
Frequency-Window Mapping:
The default configuration assigns 6 windows to 6 frequency ranges via
thresholds [500, 1000, 2000, 4000, 8000] Hz:
This design trades off time-frequency resolution to match human auditory
perception characteristics across frequency ranges.
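The mapping table is missing from this text. The sketch below shows one way the assignment could work, pairing the default window lengths with the five thresholds in order (longest window for the lowest range). Treating each threshold as the lower edge of the next range is an assumption, and the function names are hypothetical.

```python
from bisect import bisect_right

WINDOW_LENGTHS = [2048, 1024, 512, 256, 128, 64]          # samples
THRESHOLDS_HZ = [500.0, 1000.0, 2000.0, 4000.0, 8000.0]   # range boundaries

def window_index_for(freq_hz):
    # Number of thresholds at or below freq_hz selects the window.
    return bisect_right(THRESHOLDS_HZ, freq_hz)

def window_length_for(freq_hz):
    return WINDOW_LENGTHS[window_index_for(freq_hz)]

def window_duration_ms(length, fs=32000):
    return 1000.0 * length / fs
```

Under this convention a 100 Hz component is analyzed with the 2048-sample (64 ms) window and a 10 kHz component with the 64-sample (2 ms) window.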
Pre-computed Buffers:
All Hann windows are pre-computed and registered as buffers named
window_{length} (e.g., window_2048, window_1024) to ensure:
Reproducibility across devices (CPU, GPU, MPS)
Efficient memory usage (computed once, reused for all STFTs)
Correct device placement (buffers follow the module on .to(device))
Window power (sum of squared values) is also pre-computed for PSD normalization.
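A minimal sketch of this buffer pattern, assuming a hypothetical `WindowBank` module: it registers one Hann window per length under the `window_{length}` naming scheme described above, plus a window-power buffer whose name (`window_power_{length}`) is invented for illustration.

```python
import torch
from torch import nn


class WindowBank(nn.Module):
    """Hypothetical container that pre-computes Hann windows as buffers."""

    def __init__(self, window_lengths=(2048, 1024)):
        super().__init__()
        for length in window_lengths:
            window = torch.hann_window(length)
            # Buffers are saved in state_dict and follow .to(device),
            # but they are not trainable parameters.
            self.register_buffer(f"window_{length}", window)
            # Sum of squared samples, usable for PSD normalization.
            self.register_buffer(f"window_power_{length}", (window ** 2).sum())


bank = WindowBank()
print(bank.window_2048.shape, float(bank.window_power_2048))
```

Because the windows are buffers rather than parameters, `bank.parameters()` stays empty while `bank.state_dict()` still contains every window.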
Learnable Thresholds:
When learnable=True, the model can adapt the window selection during
training. This is useful for tasks where the optimal time-frequency trade-off
differs from psychoacoustic models (e.g., speech enhancement, music analysis).
The thresholds are initialized to [500, 1000, 2000, 4000, 8000] Hz and
constrained to remain sorted via projection during optimization.
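One simple way to keep learnable thresholds ordered is to re-sort them after each optimizer step. The sketch below shows that idea; the helper `project_sorted_` is hypothetical, and sorting is only one possible projection, not necessarily the one this module uses.

```python
import torch
from torch import nn

thresholds = nn.Parameter(torch.tensor([500.0, 1000.0, 2000.0, 4000.0, 8000.0]))


def project_sorted_(param):
    """Re-sort a 1-D parameter in place without recording gradients."""
    with torch.no_grad():
        param.copy_(torch.sort(param).values)


# Simulate an update that breaks the ordering, then project back.
with torch.no_grad():
    thresholds[1] = 2500.0
project_sorted_(thresholds)
print(thresholds.tolist())
```

In a training loop, the projection would be called right after `optimizer.step()`.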
Apply multi-resolution FFT analysis to audio signal.
Computes STFTs with all configured window lengths, then selects the appropriate
window for each frequency range to produce a single time-frequency representation
with frequency-dependent resolution.
Parameters:
signal (Tensor) – Input audio signal, shape (batch, time). The signal is passed to torch.stft
as-is, so it should be a floating-point tensor.
psd (torch.Tensor) – Power spectral density, shape (batch, n_frames, n_freq_bins). Units: Pa²/Hz
when input is in Pascals (typical for audio processing).
freqs (torch.Tensor) – Frequency vector in Hz, shape (n_freq_bins,). Linearly spaced from 0 to
Nyquist frequency (fs/2).
Notes
Processing Pipeline:
Compute 6 STFTs with different window lengths
For each STFT, compute PSD in its designated frequency range
Interpolate all PSDs to common frequency grid (longest window)
Merge PSDs by frequency range to create final output
Time Resolution:
n_frames depends on the longest window and hop_fraction:
Get frequency-to-window mapping for visualization and analysis.
Returns the window index used for each frequency bin in the reference (longest
window) frequency grid. Useful for understanding and visualizing how the
multi-resolution analysis distributes different windows across frequencies.
freqs (torch.Tensor) – Frequency vector in Hz, shape (n_freq_bins,), where n_freq_bins =
longest_window // 2 + 1. Linearly spaced from 0 to Nyquist.
window_indices (torch.Tensor) – Window index for each frequency, shape (n_freq_bins,), dtype torch.long.
Values range from 0 to (n_windows - 1), where 0 corresponds to the
longest window and (n_windows - 1) to the shortest.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
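The nesting described above can be reproduced with a few small modules. The sketch below is illustrative only: the layer sizes are arbitrary, and the class names `NetC` and `NetB` are chosen for this example.

```python
from torch import nn


class NetC(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(16, 33, kernel_size=3, stride=2)


class NetB(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_c = NetC()
        self.linear = nn.Linear(100, 200)


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = NetB()


a = A()
linear = a.get_submodule("net_b.linear")
conv = a.get_submodule("net_b.net_c.conv")

# A non-existent path raises AttributeError.
try:
    a.get_submodule("net_b.missing")
    found_missing = True
except AttributeError:
    found_missing = False
```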
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
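A short sketch of the returned fields, using a plain `nn.Linear` and a deliberately mismatched state dict (the key name `extra` is made up for this example):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)

# Start from a valid state dict, then drop one key and add an unknown one.
state = model.state_dict()
del state["bias"]
state["extra"] = torch.zeros(1)

# strict=False reports mismatches instead of raising.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys, result.unexpected_keys)
```

With `strict=True` (the default), the same call would raise a `RuntimeError` listing both mismatches.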
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm’s running_mean
is not a parameter, but is part of the module’s state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting persistent to False. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module’s
state_dict.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If None, then operations
that run on buffers, such as cuda, are ignored. If None,
the buffer is not included in the module’s state_dict.
persistent (bool): whether the buffer is part of this module’s
The hook will be called every time after forward() has computed an output.
If with_kwargs is False or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
output. It can modify the input inplace but it will not have effect on
forward since this is called after forward() is called. The hook
should have the following signature:
hook(module, args, output) -> None or modified output
If with_kwargs is True, the forward hook will be passed the
kwargs given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature:
hook(module, args, kwargs, output) -> None or modified output
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided hook will be fired
before all existing forward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this torch.nn.Module. Note that global
forward hooks registered with
register_module_forward_hook() will fire before all hooks
registered by this method.
Default: False
with_kwargs (bool): If True, the hook will be passed the
kwargs given to the forward function.
Default: False
always_call (bool): If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
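As a small illustration, the sketch below registers a forward hook that records output shapes and then removes it through the returned handle. The hook name `record_shape` is arbitrary.

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)
shapes = []


def record_shape(module, args, output):
    # Returning None leaves the output unchanged.
    shapes.append(tuple(output.shape))


handle = layer.register_forward_hook(record_shape)
layer(torch.zeros(5, 3))   # hook fires
handle.remove()
layer(torch.zeros(7, 3))   # hook no longer fires
print(shapes)
```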
The hook will be called every time before forward() is invoked.
If with_kwargs is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature:
hook(module, args) -> None or modified input
If with_kwargs is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature:
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing forward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this torch.nn.Module. Note that global
forward_pre hooks registered with
register_module_forward_pre_hook() will fire before all
hooks registered by this method.
Default: False
with_kwargs (bool): If true, the hook will be passed the kwargs
given to the forward function.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
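A minimal sketch of a forward pre-hook that rewrites the positional input before `forward()` runs. `nn.Identity` is used so that the effect of the hook is directly visible in the output.

```python
import torch
from torch import nn

layer = nn.Identity()


def double_input(module, args):
    # Return a tuple of replacement positional arguments.
    return (args[0] * 2,)


handle = layer.register_forward_pre_hook(double_input)
out = layer(torch.ones(2))
handle.remove()
print(out)
```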
The grad_input and grad_output are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of grad_input in
subsequent computations. grad_input will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in grad_input and grad_output will be None for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
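The sketch below registers a full backward hook that only inspects the gradients. The dictionary `captured` is just a convenient place to store them for this example.

```python
import torch
from torch import nn

layer = nn.Linear(3, 1)
captured = {}


def save_grads(module, grad_input, grad_output):
    # Clone so the stored tensors are independent of autograd internals.
    captured["grad_output"] = grad_output[0].clone()
    captured["grad_input"] = grad_input[0].clone()


handle = layer.register_full_backward_hook(save_grads)
x = torch.ones(2, 3, requires_grad=True)
layer(x).sum().backward()
handle.remove()
print(captured["grad_output"].shape, captured["grad_input"].shape)
```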
The grad_output is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of grad_output in
subsequent computations. Entries in grad_output will be None for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
Register a post-hook to be run after module’s load_state_dict() is called.
It should have the following signature::
hook(module, incompatible_keys) -> None
The module argument is the current module that this hook is registered
on, and the incompatible_keys argument is a NamedTuple consisting
of attributes missing_keys and unexpected_keys. missing_keys
is a list of str containing the missing keys and
unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling load_state_dict() with
strict=True are affected by modifications the hook makes to
missing_keys or unexpected_keys, as expected. Additions to either
set of keys will result in an error being thrown when strict=True, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
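To illustrate how modifying the keys in place affects strict loading, the sketch below registers a post-hook that tolerates a missing `bias` entry. The hook name `ignore_missing_bias` is hypothetical.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)


def ignore_missing_bias(module, incompatible_keys):
    # Removing the key in place suppresses the strict=True error for it.
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")


handle = model.register_load_state_dict_post_hook(ignore_missing_bias)

# The state dict lacks "bias", which would normally fail with strict=True.
result = model.load_state_dict({"weight": torch.zeros(2, 4)}, strict=True)
handle.remove()
print(result.missing_keys)
```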
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state
found within the state_dict. Implement this function and a corresponding
get_extra_state() for your module if you need to store extra state within its
state_dict.
Set the submodule given by target if it exists, otherwise throw an error.
Note
If strict is set to False (default), the method will replace an existing submodule
or create a new submodule if the parent module exists. If strict is set to True,
the method will only attempt to replace an existing submodule and throw an error if
the submodule does not exist.
For example, let’s say you have an nn.Module A that
looks like this:
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To override the Conv2d with a new submodule Linear, you
could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),
where strict could be True or False.
To add a new submodule Conv2d to the existing net_b module,
you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).
In the above, if you set strict=True and call
set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError
will be raised because net_b does not have a submodule named conv.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
strict: If False, the method will replace an existing submodule
or create a new submodule if the parent module exists. If True,
the method will only attempt to replace an existing submodule and throw an error
if the submodule doesn’t already exist.
Raises:
ValueError: If the target string is empty or if module is not an instance of nn.Module.
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references
to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for
destination, prefix and keep_vars in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Warning
Please avoid the use of argument destination as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional): by default the Tensors
returned in the state dict are detached from autograd. If it’s
set to True, detaching will not be performed.
Default: False.
Returns:
dict:
a dictionary containing a whole state of the module
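The following sketch shows which entries end up in the returned dictionary. The `Scaler` module and its attribute names are invented for this example.

```python
import torch
from torch import nn


class Scaler(nn.Module):
    """Hypothetical module with a parameter and two buffers."""

    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(1))
        self.register_buffer("running_peak", torch.zeros(1))
        # Non-persistent buffers are excluded from state_dict.
        self.register_buffer("scratch", torch.zeros(1), persistent=False)


module = Scaler()
state = module.state_dict()
keys = list(state.keys())

# Entries are detached references to the module's own storage.
shares_storage = state["gain"].data_ptr() == module.gain.data_ptr()
print(keys, shares_storage)
```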
Its signature is similar to torch.Tensor.to(), but only accepts
floating point or complex dtypes. In addition, this method will
only cast the floating point or complex parameters and buffers to dtype
(if given). The integral parameters and buffers will be moved to
device, if that is given, but with dtypes unchanged. When
non_blocking is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
Args:
device (torch.device): the desired device of the parameters
and buffers in this module
dtype (torch.dtype): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
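A brief sketch of the dtype behaviour described above, using `nn.BatchNorm1d` because it holds both floating-point tensors and an integer buffer:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(3)
returned = bn.to(torch.float64)  # in-place; returns the module itself

print(bn.weight.dtype)               # floating-point parameter is cast
print(bn.running_mean.dtype)         # floating-point buffer is cast
print(bn.num_batches_tracked.dtype)  # integer buffer keeps its dtype
```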
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. Dropout, BatchNorm,
etc.
Args:
mode (bool): whether to set training mode (True) or evaluation
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on XPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Multi-resolution spectral analysis for Moore et al. (2016/2018) binaural loudness model.
Implements frequency-dependent time-frequency analysis using 6 Hann windows with
lengths from 2 to 64 ms, computing sparse spectral representations with only
“relevant” frequency components. This differs fundamentally from
MultiResolutionFFT by producing sparse output (frequency/level pairs)
rather than dense power spectral densities, enabling efficient loudness computation
for time-varying binaural signals.
The model applies 6 overlapping FFT analyses with window lengths matched to the
temporal and frequency resolution requirements of auditory perception across the
frequency spectrum. Each window contributes only to specific frequency ranges,
and only components exceeding perceptual relevance thresholds are retained.
Parameters:
fs (int) – Sampling rate in Hz. Must be 32000 as required by Moore et al. (2016/2018).
Other sampling rates will raise ValueError. Default: 32000.
segment_duration (int) – Duration of each analysis segment in milliseconds. Default: 1 (as specified
in Moore et al. 2018). This determines the time resolution of the output
(1 ms per frame).
db_max (float) – Reference SPL in dB for a full-scale sinusoid. Default: 93.98, which
corresponds to dbspl(1) in the AMT MATLAB implementation. This value
calibrates intensity units so that 1.0 corresponds to 0 dB SPL.
learnable (bool) – If True, frequency band limits (freq_limits) and relevance thresholds
(threshold_max_minus, threshold_absolute) become learnable
nn.Parameter objects for optimization. Default: False (fixed parameters
from Moore et al. 2016/2018).
dtype (dtype) – Data type for internal computations and parameters. Default: torch.float32.
Use torch.float64 for higher precision if needed.
Frequency band limits for each window, shape (6, 2). Each row contains
[f_low, f_high] in Hz defining which frequency range each window analyzes.
Default: [[20, 80], [80, 500], [500, 1250], [1250, 2540], [2540, 4050],
[4050, 15000]].
Custom segment duration for coarser temporal resolution:
>>> # Use 2 ms segments instead of 1 ms (faster, less temporal detail)
>>> spectrum_2ms = Moore2016Spectrum(fs=32000, segment_duration=2)
>>> audio_short = torch.randn(1, 2, 16000)  # 500 ms
>>> freqs_l, levels_l, freqs_r, levels_r = spectrum_2ms(audio_short)
>>> print(f"Frames (2ms resolution): {freqs_l.shape[1]}")
Frames (2ms resolution): 218
Notes
Frequency-Dependent Window Assignment:
Each FFT window analyzes a specific frequency range matched to auditory
temporal/frequency resolution tradeoffs:
Sparse Output Format:
Unlike dense spectrograms, this module returns sparse representations:
only frequency components exceeding relevance criteria are included. This
reduces memory and computational requirements for downstream loudness
calculations. Zero frequencies indicate padding (no component).
The number of relevant components varies per time frame and depends on
signal characteristics. Frames with simple spectra may have ~100 components,
while complex signals may have ~900 components (out of 1024 possible bins).
Relevant Component Criteria:
A frequency component is retained if it satisfies both conditions:
Relative threshold: \(I > I_{\max} / 10^6\) (i.e., within 60 dB of maximum)
Absolute threshold: \(I > 10^{-3}\) (i.e., above -30 dB SPL when calibrated)
where \(I\) is intensity in calibrated linear units. The absolute
threshold ensures inaudible components are discarded even if they dominate
the spectrum (e.g., DC offset, low-frequency rumble).
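The two criteria can be written as a boolean mask over the intensities of one frame. The sketch below assumes that \(I_{\max}\) is taken per frame and uses made-up intensity values; the variable names are illustrative and do not reflect the module's internals.

```python
import torch

# Example intensities for one frame, in calibrated linear units.
intensity = torch.tensor([1e4, 5e-3, 5e-4, 2.0])

relative_ok = intensity > intensity.max() / 1e6  # within 60 dB of the maximum
absolute_ok = intensity > 1e-3                   # above -30 dB SPL
mask = relative_ok & absolute_ok

# Convert the retained intensities to levels in dB.
levels_db = 10.0 * torch.log10(intensity[mask])
print(mask.tolist(), levels_db.tolist())
```

Here the second value passes the absolute threshold but fails the relative one, and the third fails both, so only two components remain.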
fs (int) – Sampling rate in Hz. Must be 32000 (strictly enforced).
Default: 32000.
segment_duration (int) – Analysis segment duration in milliseconds. Determines temporal
resolution of output. Default: 1 (1 ms per output frame).
db_max (float) – Reference SPL for full-scale sinusoid in dB. Default: 93.98
(matches AMT dbspl(1) convention).
learnable (bool) – If True, frequency band limits and relevance thresholds become
learnable nn.Parameter objects. Default: False (fixed from
Moore et al. 2016/2018).
dtype (dtype) – Data type for computations. Default: torch.float32.
Raises:
ValueError – If fs != 32000. The model is specifically designed for 32 kHz
sampling rate and frequency band definitions depend on this rate.
Notes
Sampling Rate Requirement:
The 32 kHz sampling rate is strictly enforced because:
Window lengths [2048, 1024, 512, 256, 128, 64] samples correspond
to [64, 32, 16, 8, 4, 2] ms only at 32 kHz
Frequency band limits [20-80, 80-500, …] Hz are optimized for
32 kHz Nyquist frequency
Hop size of 32 samples = 1 ms temporal resolution at 32 kHz
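The arithmetic behind these figures is easy to check. The sketch below converts the window lengths and hop size to milliseconds at 32 kHz and also prints the FFT bin spacing for each window, which is a derived quantity not stated above.

```python
FS = 32000  # sampling rate in Hz
WINDOW_LENGTHS = [2048, 1024, 512, 256, 128, 64]  # samples
HOP = 32  # samples

durations_ms = [1000 * n / FS for n in WINDOW_LENGTHS]
hop_ms = 1000 * HOP / FS
bin_spacing_hz = [FS / n for n in WINDOW_LENGTHS]

print(durations_ms)     # window durations in ms
print(hop_ms)           # hop duration in ms
print(bin_spacing_hz)   # frequency resolution of each FFT in Hz
```

At any other sampling rate the same sample counts would give different durations, which is why the rate is fixed.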
Frequency Band Initialization:
The constructor sets up 6 frequency bands matched to window lengths:
Compute sparse multi-resolution spectrum for binaural audio.
Analyzes stereo audio through 6 overlapping FFT windows, applies frequency
band limiting, filters relevant components, and returns sparse frequency/level
pairs for each ear separately.
freqs_left (torch.Tensor) – Relevant frequency components for left ear, shape
(batch, n_segments, max_components). Frequencies in Hz. Zero
values indicate padding (no actual component).
levels_left (torch.Tensor) – SPL levels in dB for left ear components, shape
(batch, n_segments, max_components). Corresponds 1:1 with
freqs_left. Zero values indicate padding.
freqs_right (torch.Tensor) – Relevant frequency components for right ear, shape
(batch, n_segments, max_components).
levels_right (torch.Tensor) – SPL levels in dB for right ear components, shape
(batch, n_segments, max_components).
Raises:
ValueError – If audio is not 3D or does not have exactly 2 channels.
Notes
Processing Pipeline:
Split audio into left and right channels
For each channel independently:
Segment signal into 1 ms windows (hop = 32 samples)
For each segment, compute 6 FFTs with different window lengths
Combine FFTs by frequency range (20-80 Hz from longest window, etc.)
Apply relevance filtering (keep components within 60 dB of the maximum and above -30 dB SPL)
Extract sparse frequency/level pairs
Pad sparse outputs to uniform size across all segments
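The final padding step can be sketched with `torch.nn.utils.rnn.pad_sequence`, which is one way to obtain zero-padded tensors of uniform size. Whether the module uses this function internally is not stated here, and the frequency values are made up.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Relevant frequencies (Hz) for two segments with different component counts.
segments = [
    torch.tensor([100.0, 250.0, 1000.0]),
    torch.tensor([440.0]),
]

# Zero-pad to (n_segments, max_components); zeros mark "no component".
padded = pad_sequence(segments, batch_first=True, padding_value=0.0)
n_components = (padded > 0).sum(dim=1)
print(padded.shape, n_components.tolist())
```

Since zero marks padding, downstream code can recover the number of real components per segment by counting non-zero frequencies.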
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
Note
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. Dropout, BatchNorm,
etc.
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module
if you need to store extra state. This function is called when building the
module’s state_dict().
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we
would call get_submodule("net_b.linear"). To check whether
we have the conv submodule, we would call
get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree
of module nesting in target. A query against
named_modules achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, get_submodule should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by target
Raises:
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in state_dict match the keys returned by this module’s
state_dict() function. Default: True
assign (bool, optional): When set to False, the properties of the tensors
in the current module are preserved whereas setting it to True preserves
properties of the Tensors in the state dict. The only
exception is the requires_grad field of Parameter
for which the value from the module is preserved. Default: False
Returns:
NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing any keys that are expected
by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not
expected by this module but present in the provided state_dict.
Note:
If a parameter or buffer is registered as None and its corresponding key
exists in state_dict, load_state_dict() will raise a
RuntimeError.
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, l will be returned only once.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm’s running_mean
is not a parameter, but is part of the module’s state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting persistent to False. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module’s
state_dict.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If None, then operations
that run on buffers, such as cuda, are ignored. If None,
the buffer is not included in the module’s state_dict.
persistent (bool): whether the buffer is part of this module’s
The hook will be called every time after forward() has computed an output.
If with_kwargs is False or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
output. It can modify the input inplace but it will not have effect on
forward since this is called after forward() is called. The hook
should have the following signature:
hook(module, args, output) -> None or modified output
If with_kwargs is True, the forward hook will be passed the
kwargs given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature:
hook(module, args, kwargs, output) -> None or modified output
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided hook will be fired
before all existing forward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this torch.nn.Module. Note that global
forward hooks registered with
register_module_forward_hook() will fire before all hooks
registered by this method.
Default: False
with_kwargs (bool): If True, the hook will be passed the
kwargs given to the forward function.
Default: False
always_call (bool): If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
The hook will be called every time before forward() is invoked.
If with_kwargs is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won’t be
passed to the hooks and only to the forward. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature:
hook(module, args) -> None or modified input
If with_kwargs is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature:
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing forward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this torch.nn.Module. Note that global
forward_pre hooks registered with
register_module_forward_pre_hook() will fire before all
hooks registered by this method.
Default: False
with_kwargs (bool): If true, the hook will be passed the kwargs
given to the forward function.
Default: False
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
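As an illustration of the single-value return path, the sketch below scales the input before a linear layer sees it. The layer and the hook name are assumptions chosen so the result is easy to check by hand.

```python
import torch
from torch import nn

layer = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    layer.weight.fill_(1.0)  # the layer now just sums its three inputs

def scale_input(module, args):
    # args is the tuple of positional arguments. A single returned value
    # is wrapped into a tuple before being passed to forward().
    return args[0] * 10

handle = layer.register_forward_pre_hook(scale_input)
x = torch.ones(1, 3)
y_hooked = layer(x)   # the hook runs first, so the layer sees 10 * x
handle.remove()
y_plain = layer(x)    # without the hook the layer sees x itself
```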
The grad_input and grad_output are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of grad_input in
subsequent computations. grad_input will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in grad_input and grad_output will be None for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
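The following sketch shows a full backward hook that records the gradient shapes and returns a replacement for grad_input. The layer, the hook name, and the factor 0.5 are illustrative assumptions.

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)
seen = {}

def halve_input_grad(module, grad_input, grad_output):
    # Both arguments are tuples; they are read here, never modified in place.
    seen["grad_input"] = [tuple(g.shape) for g in grad_input if g is not None]
    seen["grad_output"] = [tuple(g.shape) for g in grad_output]
    # The returned tuple is used in place of grad_input downstream.
    return tuple(None if g is None else 0.5 * g for g in grad_input)

handle = layer.register_full_backward_hook(halve_input_grad)

x = torch.randn(5, 3, requires_grad=True)
layer(x).sum().backward()
handle.remove()

# Without the hook, d(sum)/dx would equal the column sums of the weight.
expected = layer.weight.detach().sum(dim=0).expand(5, 3)
```

Here x.grad ends up at half of expected, because only the gradient flowing to the module's inputs is replaced.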
The grad_output is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of grad_output in
subsequent computations. Entries in grad_output will be None for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module’s forward function.
Warning
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook() will fire before
all hooks registered by this method.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
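A small sketch of a backward pre-hook, assuming a toy linear layer. The hook (a hypothetical name) replaces grad_output with zeros, so every gradient computed inside the module from that point on is zero as well.

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)
seen = {}

def zero_grad_output(module, grad_output):
    # grad_output is a tuple; record its shapes and return a replacement.
    seen["shapes"] = [tuple(g.shape) for g in grad_output]
    return tuple(torch.zeros_like(g) for g in grad_output)

handle = layer.register_full_backward_pre_hook(zero_grad_output)

x = torch.randn(5, 3, requires_grad=True)
layer(x).sum().backward()
handle.remove()

weight_grad_is_zero = torch.count_nonzero(layer.weight.grad).item() == 0
input_grad_is_zero = torch.count_nonzero(x.grad).item() == 0
```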
Register a post-hook to be run after module’s load_state_dict() is called.
It should have the following signature:
hook(module, incompatible_keys) -> None
The module argument is the current module that this hook is registered
on, and the incompatible_keys argument is a NamedTuple consisting
of attributes missing_keys and unexpected_keys. missing_keys
is a list of str containing the missing keys and
unexpected_keys is a list of str containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling load_state_dict() with
strict=True are affected by modifications the hook makes to
missing_keys or unexpected_keys, as expected. Additions to either
set of keys will result in an error being thrown when strict=True, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
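To illustrate how in-place edits to incompatible_keys interact with strict=True, here is a sketch that tolerates a checkpoint without a bias entry. The tolerance policy and the hook name are assumptions made for the example.

```python
import torch
from torch import nn

model = nn.Linear(2, 2)

def ignore_missing_bias(module, incompatible_keys):
    # Hypothetical policy: accept checkpoints that lack the "bias" entry.
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")

handle = model.register_load_state_dict_post_hook(ignore_missing_bias)

# This checkpoint has no bias; strict=True would normally raise an error.
checkpoint = {"weight": torch.zeros(2, 2)}
result = model.load_state_dict(checkpoint, strict=True)
handle.remove()
```

Because the hook removed the key before the strict check ran, the load succeeds and result.missing_keys is empty.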
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state
found within the state_dict. Implement this function and a corresponding
get_extra_state() for your module if you need to store extra state within its
state_dict.
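A minimal sketch of the get_extra_state() / set_extra_state() pair, assuming a hypothetical module that keeps a plain Python float alongside its parameters.

```python
import torch
from torch import nn

class Calibrated(nn.Module):
    """Hypothetical module storing a non-tensor setting in its state_dict."""

    def __init__(self, gain=1.0):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.gain = gain

    def get_extra_state(self):
        # Called by state_dict(); the result is saved under "_extra_state".
        return {"gain": self.gain}

    def set_extra_state(self, state):
        # Called by load_state_dict() when "_extra_state" is present.
        self.gain = state["gain"]

src = Calibrated(gain=3.5)
dst = Calibrated()
state = src.state_dict()
dst.load_state_dict(state)
```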
Set the submodule given by target if it exists, otherwise throw an error.
Note
If strict is set to False (default), the method will replace an existing submodule
or create a new submodule if the parent module exists. If strict is set to True,
the method will only attempt to replace an existing submodule and throw an error if
the submodule does not exist.
For example, let’s say you have an nn.Module A that
looks like this:
(The diagram shows an nn.Module A. A has a nested
submodule net_b, which itself has two submodules net_c
and linear. net_c then has a submodule conv.)
To override the Conv2d with a new submodule Linear, you
could call set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),
where strict could be True or False.
To add a new submodule Conv2d to the existing net_b module,
you would call set_submodule("net_b.conv", nn.Conv2d(1, 1, 1)).
In the above, if you set strict=True and call
set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True), an AttributeError
will be raised because net_b does not have a submodule named conv.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
strict: If False, the method will replace an existing submodule
or create a new submodule if the parent module exists. If True,
the method will only attempt to replace an existing submodule and throw an error
if the submodule doesn’t already exist.
Raises:
ValueError: If the target string is empty or if module is not an instance of nn.Module.
AttributeError: If at any point along the path resulting from
the target string the (sub)path resolves to a non-existent
attribute name or an object that is not an instance of nn.Module.
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references
to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for
destination, prefix and keep_vars in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Warning
Please avoid the use of argument destination as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional): by default the Tensors
returned in the state dict are detached from autograd. If it’s
set to True, detaching will not be performed.
Default: False.
Returns:
dict:
a dictionary containing a whole state of the module
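The sketch below inspects the state of a stock BatchNorm1d layer (chosen only because it has both parameters and persistent buffers) and shows the effect of prefix and keep_vars.

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(3)

sd = bn.state_dict()
keys = list(sd.keys())  # parameters first, then persistent buffers

# By default the returned tensors are detached from autograd.
detached = sd["weight"].requires_grad
kept = bn.state_dict(keep_vars=True)["weight"].requires_grad

# prefix is prepended to every key.
prefixed = list(bn.state_dict(prefix="norm.").keys())

# The dictionary holds references: the entries share storage with the module.
shares_storage = sd["weight"].data_ptr() == bn.weight.data_ptr()
```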
Its signature is similar to torch.Tensor.to(), but only accepts
floating point or complex dtypes. In addition, this method will
only cast the floating point or complex parameters and buffers to dtype
(if given). The integral parameters and buffers will be moved to
device, if that is given, but with dtypes unchanged. When
non_blocking is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
Args:
device (torch.device): the desired device of the parameters
and buffers in this module
dtype (torch.dtype): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
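As a sketch of the dtype rule, the hypothetical module below holds one floating point parameter and one integer buffer. Casting changes only the former, and an integer target dtype is rejected.

```python
import torch
from torch import nn

class Counter(nn.Module):
    """Hypothetical module with a float parameter and an integer buffer."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(2))
        self.register_buffer("steps", torch.zeros(1, dtype=torch.long))

m = Counter()
ret = m.to(torch.float64)  # modifies m in place and returns it

# Integer dtypes are not accepted as a target.
try:
    m.to(torch.int32)
    rejected = False
except TypeError:
    rejected = True
```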
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. Dropout, BatchNorm,
etc.
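Dropout is a convenient way to see the difference between the two modes. This is a small sketch using a standalone layer, not one of this library's modules.

```python
import torch
from torch import nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train(False)      # evaluation mode, equivalent to drop.eval()
eval_out = drop(x)     # dropout is a no-op here

drop.train(True)       # training mode
train_out = drop(x)    # elements are zeroed or scaled by 1 / (1 - p)
values = set(train_out.unique().tolist())
```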
Args:
mode (bool): whether to set training mode (True) or evaluation
mode (False).
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on XPU while being optimized.
Note
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
Compute equivalent rectangular bandwidth (ERB) of an auditory filter.
Calculates the auditory filter bandwidth based on the center frequency using
the relationship defined by Glasberg and Moore (1990). This function computes
the bandwidth in Hz, not the ERB-rate scale value.
\[\text{BW}(f_c) = 24.7 + \frac{f_c}{9.265}\]
where \(f_c\) is the center frequency in Hz.
Parameters:
fc (Tensor) – Center frequencies in Hz. Can be scalar or any tensor shape.
Returns:
Auditory filter bandwidths in Hz. Same shape as input.
The bandwidth is an affine function of frequency. Below roughly 230 Hz the
constant 24.7 Hz term dominates, giving approximately constant-bandwidth
behavior; above it the proportional term \(f_c / 9.265\) dominates, giving
approximately constant-Q behavior.
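The formula above can be evaluated directly. The function below is an illustrative re-implementation (its name is hypothetical), not the library's own code.

```python
import torch

def audfiltbw_sketch(fc: torch.Tensor) -> torch.Tensor:
    """Evaluate BW(fc) = 24.7 + fc / 9.265 in Hz, preserving the input shape."""
    return 24.7 + fc / 9.265

fc = torch.tensor([100.0, 1000.0, 4000.0])
bw = audfiltbw_sketch(fc)

# Quality factor fc / BW: it rises with frequency and approaches 9.265.
q = fc / bw

# Frequency at which the constant and proportional terms are equal (Hz).
crossover = 24.7 * 9.265
```

The crossover near 229 Hz marks where the filters shift from roughly constant bandwidth toward roughly constant Q.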
Convert ERB-rate scale to frequency in Hz (Natural Logarithm version).
This is the inverse transformation of fc2erb(), converting from the
ERB-rate scale (Cams) back to frequency in Hz using the natural logarithm
formulation from Glasberg and Moore (1990).
Convert frequency in Hz to ERB-rate scale (Natural Logarithm version).
Transforms frequency in Hz to the ERB-rate scale (Cams) using the natural
logarithm formulation from Glasberg and Moore (1990). The ERB-rate represents
the number of equivalent rectangular bandwidths below a given frequency.
Numerical Stability: Logarithm is stable for all positive frequencies.
For fc=0, result is 0 (continuous at origin).
The ERB-rate scale provides an approximately linear representation of perceived
pitch. One unit on the ERB scale corresponds roughly to one critical bandwidth
or one auditory filter. The unit is often called “Cams” (after Cambridge).
This function uses the natural logarithm formulation. For the base-10 logarithm
version used in loudness models (Moore et al., 1997), see f2erbrate().
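Since the formulas for fc2erb() and erb2fc() are not reproduced here, the scalar sketch below assumes the commonly used natural-log form with constants 9.2645 and 0.00437. Both constants and both function names are assumptions; the library's exact values may differ.

```python
import math

# Assumed constants for the natural-log ERB-rate form (not confirmed by this
# reference).
A = 9.2645    # Cams per natural-log unit
B = 0.00437   # 1/Hz

def fc2erb_sketch(fc_hz: float) -> float:
    """Frequency in Hz to ERB-rate in Cams."""
    return A * math.log1p(B * fc_hz)

def erb2fc_sketch(erb: float) -> float:
    """ERB-rate in Cams back to frequency in Hz."""
    return math.expm1(erb / A) / B

cams_1k = fc2erb_sketch(1000.0)
round_trip = erb2fc_sketch(cams_1k)
at_dc = fc2erb_sketch(0.0)
```

Using log1p and expm1 keeps the pair accurate near 0 Hz, where the argument of the logarithm is close to 1.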
Calculate equivalent rectangular bandwidth (ERB) at a given frequency.
Computes the auditory filter bandwidth in Hz according to the formula from
Glasberg and Moore (1990) and Moore et al. (1997). This is the same
calculation as audfiltbw() but uses the alternative parametrization.
>>> import torch
>>> f = torch.tensor([100.0, 1000.0, 4000.0])
>>> erb_bw = f2erb(f)
>>> print(erb_bw)
tensor([ 35.4939, 132.6390, 456.4560])
>>> # Compare with audfiltbw (should be nearly identical)
>>> from torch_amt.common.filterbanks import audfiltbw
>>> print(torch.allclose(erb_bw, audfiltbw(f), atol=1e-1))
True
Notes
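Relation to audfiltbw: This function and audfiltbw() compute the
same quantity using slightly different parameterizations. For the values in
the example above the two agree to within about 0.03 Hz; the gap grows in
proportion to frequency, staying below 0.1 Hz up to roughly 17 kHz and
reaching about 0.12 Hz at 20 kHz.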
The ERB bandwidth represents the width in Hz of a rectangular filter that
would pass the same total power as the rounded exponential auditory filter.
One ERB is also referred to as 1 Cam in the loudness literature.
Note: This function returns bandwidth in Hz. For the ERB-rate scale (number
of ERBs from DC), use fc2erb() or f2erbrate().
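To make the comparison between the two parameterizations concrete, the scalar sketch below assumes f2erb() uses the form 24.7 (4.37 f / 1000 + 1), which reproduces the example outputs, and that audfiltbw() uses 24.7 + f / 9.265. Both helper names are hypothetical.

```python
def f2erb_sketch(f_hz: float) -> float:
    """Assumed form: 24.7 * (4.37 * f / 1000 + 1), in Hz."""
    return 24.7 * (4.37 * f_hz / 1000.0 + 1.0)

def audfiltbw_sketch(fc_hz: float) -> float:
    """Form given for audfiltbw: 24.7 + fc / 9.265, in Hz."""
    return 24.7 + fc_hz / 9.265

# Absolute difference between the two bandwidth estimates, in Hz.
diff = {f: abs(f2erb_sketch(f) - audfiltbw_sketch(f))
        for f in (100.0, 1000.0, 4000.0, 20000.0)}
```

Under these assumptions the difference scales linearly with frequency, at about 0.006 Hz per kHz.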
Convert frequency to ERB-rate using Base-10 Logarithm.
This is the base-10 logarithm variant used specifically in loudness models
such as Moore et al. (1997), Glasberg and Moore (2002), and Moore et al. (2016).
It differs from fc2erb() which uses natural logarithm.
Numerical Stability: Stable for all non-negative frequencies. In the
standard formulation the log10 argument is \(4.37 f / 1000 + 1\), which
stays between 1 and about 88 for 0-20 kHz.
Difference from fc2erb: This function produces slightly different values
(typically <0.05 Cams difference) compared to fc2erb() due to the
different logarithm base and constants. The base-10 version is standard in
loudness modeling.
The Cam scale (Cambridge ERB-rate scale) represents the cumulative number of
equivalent rectangular bandwidths from DC to a given frequency. The human
cochlea has approximately 40-43 Cams of frequency resolution from 20 Hz to 20 kHz.
Convert ERB-rate (Base-10 Logarithm) to frequency in Hz.
This is the inverse transformation of f2erbrate(), converting from the
ERB-rate scale (Cams) back to frequency in Hz. Used in loudness models that
employ the base-10 logarithm formulation.
Numerical Stability: The exponential operation (10^x) grows quickly with
ERB-rate, but for erbrate=43 Cams (approximately 20 kHz) the intermediate
result is only ~10^2, which is well within floating-point range. With that
scaling, float32 overflow would require ERB-rates above roughly 800 Cams,
far beyond any realistic value.
This function is the inverse of f2erbrate() and is used in loudness
models to convert back from the perceptual ERB-rate scale to physical frequency.
For the natural logarithm version, see erb2fc().
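The base-10 pair can be sketched in scalar form as follows, assuming the formulation 21.4 log10(4.37 f / 1000 + 1) that is common in the loudness literature. The constants and the function names are assumptions, not taken from the library source.

```python
import math

def f2erbrate_sketch(f_hz: float) -> float:
    """Frequency in Hz to ERB-rate in Cams (assumed base-10 form)."""
    return 21.4 * math.log10(4.37 * f_hz / 1000.0 + 1.0)

def erbrate2f_sketch(cams: float) -> float:
    """ERB-rate in Cams back to frequency in Hz (inverse of the above)."""
    return (10.0 ** (cams / 21.4) - 1.0) * 1000.0 / 4.37

cams_1k = f2erbrate_sketch(1000.0)
round_trip = erbrate2f_sketch(cams_1k)

# Intermediate power of ten at 43 Cams, the value discussed above.
intermediate = 10.0 ** (43.0 / 21.4)
```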
Generate ERB-spaced center frequencies for auditory filterbanks.
Creates a vector of frequencies spaced equidistantly on the ERB-rate scale,
which provides perceptually uniform spacing that matches the frequency
resolution of the human auditory system. This is the standard method for
selecting center frequencies in auditory models.
\[f_i = \text{erb2fc}(\text{fc2erb}(f_{\text{low}}) + i \cdot \text{bwmul})\]
for \(i = 0, 1, 2, \ldots, N-1\) where \(N\) is determined by the
frequency range and spacing density.
Parameters:
flow (float) – Lowest center frequency in Hz. Must be positive.
fhigh (float) – Highest center frequency in Hz. Must be greater than flow.
bwmul (float) – Spacing density in ERB units. Default is 1.0 (one filter per ERB).
- bwmul < 1.0: Denser spacing (overlapping filters).
- bwmul = 1.0: Standard spacing (adjacent ERBs).
- bwmul > 1.0: Sparser spacing (gaps between filters).
basef (float | None) – Reference frequency in Hz. If provided, one filter is placed exactly at
this frequency, with others spaced symmetrically above and below.
Useful for ensuring coverage of specific frequencies of interest.
Default is None (uniform spacing from flow to fhigh).
device (device | None) – Torch device for tensor creation (CPU, CUDA, or MPS).
Default is None (uses CPU).
dtype (dtype) – Data type for output tensor. Default is torch.float32.
Returns:
Vector of center frequencies in Hz, monotonically increasing.
Shape: [num_channels] where num_channels depends on the frequency
range and bwmul.
>>> import torch
>>> from torch_amt.common.filterbanks import erbspacebw
>>> # Standard ERB spacing from 100 to 8000 Hz
>>> fc = erbspacebw(100.0, 8000.0)
>>> print(f"Number of channels: {len(fc)}")
Number of channels: 30
>>> print(fc[:3])
tensor([100.0000, 138.6141, 181.7625])
>>>
>>> # Denser spacing (2 filters per ERB)
>>> fc_dense = erbspacebw(100.0, 8000.0, bwmul=0.5)
>>> print(f"Denser: {len(fc_dense)} channels")
Denser: 60 channels
>>>
>>> # Ensure 1000 Hz is included
>>> fc_1k = erbspacebw(100.0, 8000.0, basef=1000.0)
>>> # Note: basef may not appear exactly due to ERB discretization
>>> closest_idx = torch.argmin(torch.abs(fc_1k - 1000.0))
>>> print(f"Closest to 1000 Hz: {fc_1k[closest_idx]:.2f} Hz")
Closest to 1000 Hz: 1002.31 Hz
>>>
>>> # Use with Metal (MPS) device
>>> if torch.backends.mps.is_available():
...     fc_mps = erbspacebw(100.0, 8000.0, device=torch.device('mps'))
...     print(f"Device: {fc_mps.device}")
Device: mps:0
Notes
Typical Usage:
Speech models: 80-8000 Hz with bwmul=1.0 (~30 channels)
Music models: 20-20000 Hz with bwmul=1.0 (~43 channels)
High-resolution: Any range with bwmul=0.5 (doubles channel count)
basef Behavior: When basef is specified, the actual frequencies may
differ slightly from perfect ERB spacing near the base frequency due to
discrete channel constraints. The algorithm places one channel as close as
possible to basef and spaces others uniformly in ERB scale.
The ERB-rate scale provides perceptually uniform spacing that matches the
frequency resolution of cochlear filters. One ERB corresponds to the bandwidth
of one auditory filter, so bwmul=1.0 provides approximately critical-band spacing.
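The spacing rule \(f_i = \text{erb2fc}(\text{fc2erb}(f_{\text{low}}) + i \cdot \text{bwmul})\) can be sketched in plain Python as below. The ERB-rate constants (9.2645 and 0.00437) and all function names are assumptions, and basef handling is omitted. The exact frequencies printed by erbspacebw() depend on the library's constants and endpoint handling, so this sketch is not expected to match them digit for digit.

```python
import math

A = 9.2645    # assumed scale constant of the natural-log ERB-rate form
B = 0.00437   # assumed frequency constant, in 1/Hz

def fc2erb_sketch(fc_hz: float) -> float:
    return A * math.log1p(B * fc_hz)

def erb2fc_sketch(erb: float) -> float:
    return math.expm1(erb / A) / B

def erbspace_sketch(flow: float, fhigh: float, bwmul: float = 1.0) -> list:
    """Frequencies from flow upward in steps of bwmul Cams, not exceeding fhigh."""
    if not 0.0 < flow < fhigh:
        raise ValueError("require 0 < flow < fhigh")
    if bwmul <= 0.0:
        raise ValueError("bwmul must be positive")
    erb_low = fc2erb_sketch(flow)
    erb_high = fc2erb_sketch(fhigh)
    # Number of whole steps that fit in the range, plus the starting point.
    n = int(math.floor((erb_high - erb_low) / bwmul)) + 1
    return [erb2fc_sketch(erb_low + i * bwmul) for i in range(n)]

fc = erbspace_sketch(100.0, 8000.0)
fc_dense = erbspace_sketch(100.0, 8000.0, bwmul=0.5)
```

Halving bwmul halves the step in Cams, which is why the channel count doubles from 30 to 60 for this range.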