merlin.algorithms.feed_forward_legacy module
Note
This module exposes the legacy feed-forward implementation kept for backward compatibility. For new code, prefer merlin.algorithms.feed_forward and FeedForwardBlock.
- merlin.algorithms.feed_forward_legacy.create_circuit(M, input_size)
Create a quantum photonic circuit with beam splitters and phase shifters.
- Parameters:
- Returns:
A quantum photonic circuit with alternating beam splitter layers and phase shifters.
- Return type:
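As a rough illustration of a circuit with "alternating beam splitter layers and phase shifters", the NumPy sketch below composes a brick-wall unitary on M modes. It is not create_circuit itself. The 50:50 splitting ratio, the even/odd coupling pattern, the random phases standing in for trainable ones, and the helper names beam_splitter and brickwall_unitary are all assumptions. The sketch also does not model how input_size maps to phase shifters.

```python
import numpy as np


def beam_splitter(theta: float) -> np.ndarray:
    """2x2 unitary of a lossless beam splitter with mixing angle theta."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, 1j * s], [1j * s, c]], dtype=complex)


def brickwall_unitary(M: int, n_layers: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: alternating beam-splitter layers, each followed by per-mode phases."""
    U = np.eye(M, dtype=complex)
    for layer in range(n_layers):
        # Even layers couple modes (0,1), (2,3), ...; odd layers couple (1,2), (3,4), ...
        start = layer % 2
        bs_layer = np.eye(M, dtype=complex)
        for i in range(start, M - 1, 2):
            bs_layer[i:i + 2, i:i + 2] = beam_splitter(np.pi / 4)  # 50:50 splitter
        # One phase shifter per mode; random phases stand in for trainable parameters
        phases = np.diag(np.exp(1j * rng.uniform(0.0, 2.0 * np.pi, size=M)))
        U = phases @ bs_layer @ U
    return U


U = brickwall_unitary(M=4, n_layers=4, rng=np.random.default_rng(0))
print(np.allclose(U.conj().T @ U, np.eye(4)))  # checks that the composed circuit is unitary
```

Alternating the coupling offset between layers is what lets light from any input mode reach every other mode after enough layers.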
- merlin.algorithms.feed_forward_legacy.define_layer_no_input(n_modes, n_photons, circuit_type=None)
Define a quantum layer for feed-forward processing.
- Parameters:
- Returns:
A configured quantum layer with trainable parameters.
- Return type:
- merlin.algorithms.feed_forward_legacy.define_layer_with_input(M, N, input_size, circuit_type=None)
Define the first layers of the feed-forward block, those with an input size > 0.
- Parameters:
- Returns:
The first quantum layer with input parameters.
- Return type:
- class merlin.algorithms.feed_forward_legacy.FeedForwardBlockLegacy(input_size, n, m, depth=None, state_injection=False, conditional_modes=None, layers=None, circuit_type=None, device=None)
Bases: Module
Feed-forward quantum neural network for photonic computation.
This class models a conditional feed-forward architecture for quantum photonic circuits. It connects multiple quantum layers in a branching tree structure, where each branch corresponds to a sequence of photon-detection outcomes on the designated conditional modes.
Each node in this feed-forward tree is a QuantumLayer that acts on a quantum state conditioned on the measurement results of the previous layers.
The recursion continues until the specified depth is reached, so the model simulates the conditional evolution of the quantum state along every measurement branch.
Detector support: The current feed-forward implementation expects amplitude access for every intermediate layer (MeasurementStrategy.AMPLITUDES) and therefore assumes ideal photon-number-resolving (PNR) detectors. Custom detector transforms or Perceval experiments with threshold / hybrid detectors are not yet supported inside this block.
- Parameters:
input_size (int) – Number of classical input features used for hybrid quantum-classical computation.
n (int) – Number of photons in the system.
m (int) – Total number of photonic modes.
depth (int | None) – Maximum depth of feed-forward recursion. Defaults to m - 1 if not specified.
state_injection (bool) – If True, allows re-injecting quantum states at intermediate steps, which is useful for simulating sources or ancilla modes. Defaults to False.
conditional_modes (list[int] | None) – List of mode indices on which photon detection is performed. This determines the branching structure. Defaults to [0].
layers (list | None) – Predefined list of quantum layers. If not provided, layers are generated automatically.
circuit_type (str | None) – Type of quantum circuit architecture used to build each layer. Acts as a template selector for circuit structure generation.
device (torch.device | str | None) – Target device for the module and all generated layers.
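The documented defaults for depth and conditional_modes can be resolved as in the sketch below. The function resolve_ff_defaults is hypothetical and only mirrors the defaults stated above (m - 1 and [0]); it is not the constructor's code.

```python
from typing import List, Optional, Tuple


def resolve_ff_defaults(
    m: int,
    depth: Optional[int] = None,
    conditional_modes: Optional[List[int]] = None,
) -> Tuple[int, List[int]]:
    """Hypothetical helper applying the documented FeedForwardBlockLegacy defaults."""
    if depth is None:
        depth = m - 1  # documented default: one level fewer than the number of modes
    if conditional_modes is None:
        conditional_modes = [0]  # fresh list per call, avoiding a shared mutable default
    return depth, conditional_modes


print(resolve_ff_defaults(4))
print(resolve_ff_defaults(4, depth=2, conditional_modes=[0, 1]))
```

Using None as the sentinel in the signature, rather than a literal [0], keeps callers from accidentally sharing and mutating one default list.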
- define_ff_layer(k, layers)
Replace quantum layers at a specific depth k.
- Parameters:
k (int) – Feed-forward layer depth index.
layers (list[QuantumLayer]) – List of replacement layers.
- Raises:
AssertionError – If layers does not have the expected length.
- define_layers(circuit_type)
Define and instantiate all quantum layers for each measurement outcome path.
Each tuple (representing a branch of the feed-forward tree) is mapped to a QuantumLayer object. Depending on whether the state injection mode is active, the number of modes/photons and the input size differ.
- Parameters:
circuit_type (str | None) – Template name or circuit architecture type.
- Raises:
AssertionError – If total input size does not match after allocation.
- Return type:
- forward(x)
Perform the full quantum-classical feedforward computation.
- Parameters:
x (torch.Tensor) – Classical input tensor of shape (batch_size, input_size).
- Returns:
Final output tensor containing probabilities for each terminal measurement configuration.
- Return type:
- Raises:
ValueError – If the trailing input dimension does not match self.input_size.
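The ValueError condition above amounts to a check on the last axis of x. The NumPy sketch below shows that check in isolation. The name validate_input is hypothetical, and the real forward operates on torch.Tensor rather than NumPy arrays.

```python
import numpy as np


def validate_input(x, input_size: int) -> np.ndarray:
    """Hypothetical check mirroring the documented ValueError on the trailing dimension."""
    x = np.asarray(x)
    # A 0-d input has no trailing dimension, so it is rejected as well
    if x.ndim == 0 or x.shape[-1] != input_size:
        got = None if x.ndim == 0 else x.shape[-1]
        raise ValueError(f"Expected trailing dimension {input_size}, got {got}")
    return x


batch = validate_input(np.zeros((8, 3)), input_size=3)  # shape (batch_size, input_size)
```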
- generate_possible_tuples()
Generate all possible conditional outcome tuples.
Each tuple represents one possible sequence of photon detection results across all conditional modes up to a given depth. For example, with n_cond = 2 and depth = 3, tuples correspond to binary sequences of length depth * n_cond = 6.
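Following the description above, the full set of binary outcome paths is a Cartesian product, which itertools.product generates directly. possible_tuples is a hypothetical name. The sketch produces only full-length paths and assumes each conditional mode yields a 0/1 outcome. It does not reproduce any intermediate-depth prefixes the real method may also track.

```python
from itertools import product
from typing import List, Tuple


def possible_tuples(n_cond: int, depth: int) -> List[Tuple[int, ...]]:
    """Hypothetical sketch: every 0/1 outcome path of length depth * n_cond."""
    # Each of the n_cond conditional modes gives a 0/1 result at every depth level
    return list(product((0, 1), repeat=depth * n_cond))


paths = possible_tuples(n_cond=2, depth=3)
print(len(paths))  # 2 ** 6 = 64 paths
```

The number of paths grows as 2 ** (depth * n_cond), which is why the depth and the number of conditional modes dominate the cost of the tree.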
- get_output_size()
Compute the number of output channels (post-measurement outcomes).
- input_size_ff_layer(k)
Return the list of input sizes for all layers at depth k.
- iterate_feedforward(current_tuple, remaining_amplitudes, keys, accumulated_prob, intermediary, outputs, depth=0, x=None)
Recursive feedforward traversal of the quantum circuit tree.
- At each step:
Evaluate photon detection outcomes (0/1) on conditional modes.
For each possible combination, compute probabilities.
Apply the corresponding quantum layer and recurse deeper.
- Parameters:
current_tuple (tuple[int, ...]) – Current measurement sequence path.
remaining_amplitudes (torch.Tensor) – Quantum amplitudes of current state.
keys (list[tuple[int, ...]]) – Fock basis keys for amplitudes.
accumulated_prob (torch.Tensor | float) – Product of probabilities so far.
intermediary (dict) – Stores intermediate probabilities.
outputs (dict) – Stores final output probabilities for all branches.
depth (int) – Current recursion depth. Default is 0.
x (torch.Tensor | None) – Classical input features.
- Return type:
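The per-step measurement listed under "At each step" can be sketched with NumPy. The sketch assumes amplitudes are a 1-D array aligned with a list of Fock-basis tuples. The function split_on_conditional_modes is hypothetical, and it groups keys by the exact photon count on the conditional modes. That reduces to the 0/1 outcomes above when at most one photon can occupy each conditional mode.

```python
from collections import defaultdict

import numpy as np


def split_on_conditional_modes(amplitudes, keys, conditional_modes):
    """Hypothetical sketch: {outcome: (probability, renormalized amplitudes, reduced keys)}."""
    amplitudes = np.asarray(amplitudes, dtype=complex)
    groups = defaultdict(list)
    for idx, key in enumerate(keys):
        # Outcome = photon counts observed on the measured (conditional) modes
        groups[tuple(key[m] for m in conditional_modes)].append(idx)
    result = {}
    for outcome, idxs in groups.items():
        sub = amplitudes[idxs]
        prob = float(np.sum(np.abs(sub) ** 2))  # Born rule for this outcome
        # Post-measurement state lives on the unmeasured modes only
        reduced = [tuple(v for m, v in enumerate(keys[i]) if m not in conditional_modes) for i in idxs]
        renorm = sub / np.sqrt(prob) if prob > 0 else sub
        result[outcome] = (prob, renorm, reduced)
    return result


# One photon spread evenly over two modes, measuring mode 0
res = split_on_conditional_modes([1 / np.sqrt(2), 1 / np.sqrt(2)], [(1, 0), (0, 1)], [0])
```

In a full traversal, each branch probability would be multiplied into accumulated_prob, and the renormalized amplitudes would be fed to that branch's QuantumLayer before recursing. That part depends on the layer objects and is not shown here.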
- property output_keys
Return cached output keys, or compute them via a dummy forward pass.
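The lazy "cached, or computed via a dummy forward pass" behaviour follows a common property-caching pattern, sketched below. KeyedBlock and its fake keys are hypothetical stand-ins, not the real class. They only show that the dummy pass runs once and later reads hit the cache.

```python
import numpy as np


class KeyedBlock:
    """Hypothetical stand-in illustrating a lazily cached output_keys property."""

    def __init__(self, input_size: int):
        self.input_size = input_size
        self._output_keys = None
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        # A real block would record keys while traversing the tree; two fake keys here
        self._output_keys = [(0,), (1,)]
        return np.zeros((x.shape[0], len(self._output_keys)))

    @property
    def output_keys(self):
        if self._output_keys is None:
            # Dummy pass with a single zero sample populates the cache
            self.forward(np.zeros((1, self.input_size)))
        return self._output_keys


block = KeyedBlock(input_size=3)
keys = block.output_keys  # triggers the dummy forward pass once
```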
- parameters()
Iterate over all trainable parameters from every quantum layer.
- size_ff_layer(k)
Return the number of feed-forward branches at layer depth k.
- to(device)
Move the block and all QuantumLayers to the specified device.
- Parameters:
device (str | torch.device) – Target device ("cpu", "cuda", "mps", etc.).
- Returns:
self on the requested device.
- Return type:
- class merlin.algorithms.feed_forward_legacy.PoolingFeedForwardLegacy(n_modes, n_photons, n_output_modes, pooling_modes=None, no_bunching=None)
Bases: Module
A quantum-inspired pooling module that aggregates amplitude information from an input quantum state representation into a lower-dimensional output space.
This module computes mappings between input and output Fock states (defined by keys_in and keys_out) based on a specified pooling scheme. It then aggregates the amplitudes according to these mappings, normalizing the result to preserve probabilistic consistency.
- Parameters:
n_modes (int) – Number of input modes in the quantum circuit.
n_photons (int) – Number of photons used in the quantum simulation.
n_output_modes (int) – Number of output modes after pooling.
pooling_modes (list[list[int]] | None) – Specifies how input modes are grouped (pooled) into output modes. Each sublist contains the indices of input modes to pool together for one output mode. If None, an even pooling scheme is automatically generated.
no_bunching (bool | None) – Deprecated and now removed; use computation_space in MeasurementStrategy instead.
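The text above does not define the automatically generated "even pooling scheme". One plausible reading is contiguous groups of near-equal size, which numpy.array_split produces. even_pooling is a hypothetical name, and the contiguous-block layout is an assumption.

```python
import numpy as np


def even_pooling(n_modes: int, n_output_modes: int) -> list:
    """Hypothetical sketch: split mode indices into contiguous, near-equal groups."""
    # array_split gives the first groups one extra element when sizes do not divide evenly
    groups = np.array_split(np.arange(n_modes), n_output_modes)
    return [g.tolist() for g in groups]


print(even_pooling(6, 3))  # [[0, 1], [2, 3], [4, 5]]
```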
- match_indices
torch.Tensor containing the indices mapping input states to output states.
- Type:
- exclude_indices
torch.Tensor containing indices of input states that have no valid mapping to an output state.
- Type:
- forward(amplitudes)
Forward pass that pools input quantum amplitudes into output modes.
- Parameters:
amplitudes (torch.Tensor) – Input tensor of shape (batch_size, n_input_states) containing the complex amplitudes (or real/imag parts) of quantum states.
- Returns:
Normalized pooled amplitudes of shape (batch_size, n_output_states).
- Return type:
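Given match and exclusion indices, the pooling step can be written as a matrix product followed by per-row normalization. In the NumPy sketch below, pool_amplitudes is hypothetical. It assumes amplitudes are summed coherently into their output state and that match_indices lists one output index per non-excluded input, in input order. The source does not say whether the real module sums amplitudes or probabilities.

```python
import numpy as np


def pool_amplitudes(amplitudes, match_indices, exclude_indices, n_output_states):
    """Hypothetical sketch: sum matched input amplitudes per output state, then L2-normalize."""
    amplitudes = np.asarray(amplitudes)
    n_in = amplitudes.shape[1]
    match_indices = np.asarray(match_indices, dtype=int)
    exclude_indices = np.asarray(exclude_indices, dtype=int)
    keep = np.setdiff1d(np.arange(n_in), exclude_indices)  # kept inputs, ascending order
    # 0/1 pooling matrix: row i routes kept input state i to its output state
    P = np.zeros((n_in, n_output_states))
    P[keep, match_indices] = 1.0
    pooled = amplitudes @ P
    norms = np.linalg.norm(pooled, axis=1, keepdims=True)
    return pooled / np.where(norms > 0, norms, 1.0)  # leave all-zero rows untouched


out = pool_amplitudes(np.array([[1.0, 1.0, 1.0]]), [0, 1, 1], [], n_output_states=2)
```

Building an explicit pooling matrix keeps the batch dimension intact and turns the aggregation into a single matrix multiplication.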
- match_tuples(keys_in, keys_out, pooling_modes)
Match input and output Fock state tuples based on pooling configuration.
For each input Fock state (key_in), the corresponding pooled output state (key_out) is computed by summing the photon counts over each pooling group. Input states that do not correspond to a valid output state are marked for exclusion.
- Parameters:
- Returns:
A pair (indices, exclude_indices), where indices are the matched indices from input to output keys, and exclude_indices are the input indices with no valid match.
- Return type:
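The matching rule described above (sum photon counts per pooling group, then look up the result among the output keys) can be sketched in plain Python. This standalone match_tuples is a hypothetical re-implementation, not the method itself. It uses one convention that the source leaves open: indices holds one output index per non-excluded input, in input order.

```python
def match_tuples(keys_in, keys_out, pooling_modes):
    """Hypothetical sketch of pooled Fock-state matching."""
    out_index = {key: i for i, key in enumerate(keys_out)}
    indices, exclude_indices = [], []
    for i, key_in in enumerate(keys_in):
        # Pooled state: total photon count inside each group of input modes
        key_out = tuple(sum(key_in[m] for m in group) for group in pooling_modes)
        j = out_index.get(key_out)
        if j is None:
            exclude_indices.append(i)  # pooled state absent from the output basis
        else:
            indices.append(j)
    return indices, exclude_indices


# 4 modes, 2 photons, no bunching; pool (0,1) and (2,3); output basis allows only (1, 1)
keys_in = [(1, 1, 0, 0), (1, 0, 1, 0), (1, 0, 0, 1), (0, 1, 1, 0), (0, 1, 0, 1), (0, 0, 1, 1)]
idx, excl = match_tuples(keys_in, [(1, 1)], [[0, 1], [2, 3]])
```

Here (1, 1, 0, 0) pools to (2, 0) and (0, 0, 1, 1) pools to (0, 2). Neither is in the output basis, so inputs 0 and 5 are excluded.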