merlin.core.process module
Quantum computation processes and factories.
- class merlin.core.process.AbstractComputationProcess
Bases: ABC
Abstract base class for quantum computation processes.
- abstract compute(*args, **kwargs)
Perform the computation.
- Args:
*args: Positional arguments for the computation.
**kwargs: Keyword arguments for the computation.
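The contract above can be illustrated with the standard `abc` module alone: a subclass must override `compute()` before it can be instantiated. This is a minimal sketch; `AbstractComputationProcessSketch` and `SquareProcess` are hypothetical stand-ins, not merlin classes.

```python
from abc import ABC, abstractmethod


class AbstractComputationProcessSketch(ABC):
    """Hypothetical stand-in mirroring the documented interface (not merlin's class)."""

    @abstractmethod
    def compute(self, *args, **kwargs):
        """Perform the computation."""


class SquareProcess(AbstractComputationProcessSketch):
    """Hypothetical concrete process: squares every positional argument."""

    def compute(self, *args, **kwargs):
        scale = kwargs.get("scale", 1)
        return [scale * x * x for x in args]


# The abstract base itself cannot be instantiated.
try:
    AbstractComputationProcessSketch()
except TypeError as exc:
    print("abstract:", exc)

print(SquareProcess().compute(2, 3, scale=10))  # [40, 90]
```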
- class merlin.core.process.CircuitConverter(circuit, input_specs=None, dtype=torch.complex64, device=device(type='cpu'))
Bases: object
Convert a parameterized Perceval circuit into a differentiable PyTorch unitary matrix.
This class converts Perceval quantum circuits into PyTorch tensors that can be used in neural network training with automatic differentiation. It supports batch processing for efficient training and handles various quantum components like beam splitters, phase shifters, and unitary operations.
- Supported Components:
PS (Phase Shifter)
BS (Beam Splitter)
PERM (Permutation)
Unitary (Generic unitary matrix)
Barrier (no-op, removed during compilation)
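The conversion idea can be sketched independently of merlin: each component becomes a small complex matrix built from differentiable torch operations, is embedded into the full mode space, and the pieces are composed by matrix multiplication. The helper names below are hypothetical, and the beam-splitter matrix assumes Perceval's `BS.Rx` convention with default phases, [[cos(θ/2), i sin(θ/2)], [i sin(θ/2), cos(θ/2)]]; this is not the converter's actual code.

```python
import torch


def phase_shifter(phi):
    """1x1 unitary exp(i*phi), built from differentiable real ops."""
    return torch.complex(torch.cos(phi), torch.sin(phi)).reshape(1, 1)


def beam_splitter_rx(theta):
    """2x2 beam splitter, assuming [[cos(t/2), i sin(t/2)], [i sin(t/2), cos(t/2)]]."""
    c = torch.complex(torch.cos(theta / 2), torch.zeros_like(theta))
    s = torch.complex(torch.zeros_like(theta), torch.sin(theta / 2))
    return torch.stack([torch.stack([c, s]), torch.stack([s, c])])


def embed(block, m, first_mode):
    """Place a k x k block on modes first_mode .. first_mode + k - 1 of an m-mode identity."""
    k = block.shape[0]
    parts = []
    if first_mode > 0:
        parts.append(torch.eye(first_mode, dtype=block.dtype))
    parts.append(block)
    if m - first_mode - k > 0:
        parts.append(torch.eye(m - first_mode - k, dtype=block.dtype))
    return torch.block_diag(*parts)


phi = torch.tensor(0.3, requires_grad=True)
theta = torch.tensor(0.5, requires_grad=True)

# Circuit(2) // PS(phi) on mode 0 // BS.Rx(theta): later components multiply on the left.
u = beam_splitter_rx(theta) @ embed(phase_shifter(phi), 2, 0)
print(torch.allclose(u.conj().T @ u, torch.eye(2, dtype=u.dtype), atol=1e-6))  # True

# |U[0, 0]|^2 = cos^2(theta / 2), so its derivative w.r.t. theta is -sin(theta) / 2
(u[0, 0].abs() ** 2).backward()
print(theta.grad)
```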
- Attributes:
circuit: The Perceval circuit to convert
param_mapping: Maps parameter names to tensor indices
device: PyTorch device for tensor operations
tensor_cdtype: Complex tensor dtype
tensor_fdtype: Float tensor dtype
- Example:
Basic usage with a single phase shifter:
    >>> import torch
    >>> import perceval as pcvl
    >>> from merlin.pcvl_pytorch.locirc_to_tensor import CircuitConverter
    >>>
    >>> # Create a simple circuit with one phase shifter
    >>> circuit = pcvl.Circuit(1) // pcvl.PS(pcvl.P("phi"))
    >>>
    >>> # Convert to PyTorch with gradient tracking
    >>> converter = CircuitConverter(circuit, input_specs=["phi"])
    >>> phi_params = torch.tensor([0.5], requires_grad=True)
    >>> unitary = converter.to_tensor(phi_params)
    >>> print(unitary.shape)  # torch.Size([1, 1])
Multiple parameters with grouping:
    >>> # Circuit with multiple phase shifters
    >>> circuit = (pcvl.Circuit(2)
    ...            // pcvl.PS(pcvl.P("theta1"))
    ...            // (1, pcvl.PS(pcvl.P("theta2"))))
    >>>
    >>> converter = CircuitConverter(circuit, input_specs=["theta"])
    >>> theta_params = torch.tensor([0.1, 0.2], requires_grad=True)
    >>> unitary = converter.to_tensor(theta_params)
    >>> print(unitary.shape)  # torch.Size([2, 2])
Batch processing for training:
    >>> # Batch of parameter values for the single phase-shifter circuit
    >>> circuit = pcvl.Circuit(1) // pcvl.PS(pcvl.P("phi"))
    >>> converter = CircuitConverter(circuit, input_specs=["phi"])
    >>> batch_params = torch.tensor([[0.1], [0.2], [0.3]], requires_grad=True)
    >>> batch_unitary = converter.to_tensor(batch_params)
    >>> print(batch_unitary.shape)  # torch.Size([3, 1, 1])
Training integration:
    >>> # Training loop with a beam splitter
    >>> circuit = pcvl.Circuit(2) // pcvl.BS.Rx(pcvl.P("theta"))
    >>> converter = CircuitConverter(circuit, ["theta"])
    >>> theta = torch.tensor([0.5], requires_grad=True)
    >>> optimizer = torch.optim.Adam([theta], lr=0.01)
    >>>
    >>> for step in range(10):
    ...     optimizer.zero_grad()
    ...     unitary = converter.to_tensor(theta)
    ...     # Illustrative objective: drive |U[0, 0]|^2 towards zero
    ...     loss = unitary[0, 0].abs() ** 2
    ...     loss.backward()
    ...     optimizer.step()
- set_dtype(dtype)
Set the tensor data types for float and complex operations.
- Args:
dtype: Target dtype (float32/complex64 or float64/complex128)
- Raises:
TypeError: If dtype is not supported
- to(dtype, device)
Move the converter to a specific device and dtype.
- Args:
dtype: Target tensor dtype (float32/complex64 or float64/complex128)
device: Target device (string or torch.device)
- Returns:
Self for method chaining
- Raises:
TypeError: If device type or dtype is not supported
- to_tensor(*input_params, batch_size=None)
Convert the parameterized circuit to a PyTorch unitary tensor.
- Return type:
torch.Tensor
- Args:
*input_params: Variable number of parameter tensors. Each tensor has shape (num_params,) or (batch_size, num_params), in the order given by input_specs.
batch_size: Explicit batch size. If None, it is inferred from the input tensors.
- Returns:
- Complex unitary tensor of shape (circuit.m, circuit.m) for single samples
or (batch_size, circuit.m, circuit.m) for batched inputs.
- Raises:
ValueError: If the wrong number of input tensors is provided.
TypeError: If input_params is not a list or tuple.
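The choice between a (circuit.m, circuit.m) and a (batch_size, circuit.m, circuit.m) result hinges on how the batch size is inferred. A minimal sketch of one plausible rule, assuming 1-D tensors are unbatched and 2-D tensors carry the batch on their first axis; `infer_batch_size` is a hypothetical helper, not the converter's code.

```python
import torch


def infer_batch_size(*input_params, batch_size=None):
    """Infer a batch size from tensors shaped (num_params,) or (batch_size, num_params).

    Returns None for purely unbatched inputs (hypothetical rule).
    """
    inferred = None
    for p in input_params:
        if p.dim() == 2:
            if inferred is not None and p.shape[0] != inferred:
                raise ValueError(f"inconsistent batch sizes: {inferred} vs {p.shape[0]}")
            inferred = p.shape[0]
        elif p.dim() != 1:
            raise ValueError(f"expected a 1-D or 2-D tensor, got shape {tuple(p.shape)}")
    if batch_size is not None and inferred is not None and batch_size != inferred:
        raise ValueError("explicit batch_size does not match the inputs")
    return batch_size if batch_size is not None else inferred


print(infer_batch_size(torch.zeros(2)))                    # None
print(infer_batch_size(torch.zeros(3, 2), torch.zeros(4)))  # 3
```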
- class merlin.core.process.Combinadics(scheme, n, m)
Bases: object
Rank/unrank Fock states in descending lexicographic order.
Parameters
- scheme : str
Enumeration strategy. Supported values are "fock", "unbunched", and "dual_rail".
- n : int
Number of photons. Must be non-negative.
- m : int
Number of modes. Must be at least one.
Raises
- ValueError
If an unsupported scheme is provided or the parameters violate the constraints of the selected scheme.
- compute_space_size()
Return the number of admissible Fock states for this configuration.
- Return type:
int
Returns
- int
Cardinality of the state space.
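The state-space sizes follow from standard counting arguments: stars and bars for "fock", choosing n distinct occupied modes for "unbunched". The sketch below encodes them; it assumes "dual_rail" means exactly one photon per mode pair (so m = 2n), which this page does not state, and `space_size` is a hypothetical function rather than the class method.

```python
from math import comb


def space_size(scheme: str, n: int, m: int) -> int:
    """Count admissible n-photon, m-mode states under each scheme.

    Assumes "dual_rail" means one photon per mode pair, so m must equal 2 * n.
    """
    if n < 0 or m < 1:
        raise ValueError("need n >= 0 and m >= 1")
    if scheme == "fock":
        return comb(n + m - 1, n)  # stars and bars
    if scheme == "unbunched":
        return comb(m, n)  # choose n distinct occupied modes; 0 if n > m
    if scheme == "dual_rail":
        if m != 2 * n:
            raise ValueError("dual_rail assumes m == 2 * n")
        return 2 ** n  # each photon sits in one of its two rails
    raise ValueError(f"unsupported scheme: {scheme!r}")


print(space_size("fock", 2, 3), space_size("unbunched", 2, 3), space_size("dual_rail", 2, 4))
# 6 3 4
```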
- enumerate_states()
Return all admissible states in descending lexicographic order.
- Return type:
list[tuple[int,...]]
Returns
- list[tuple[int, ...]]
State list matching iter_states().
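To make "descending lexicographic order" concrete: states with more photons in the first mode come first, and ties are broken recursively on the remaining modes. The brute-force sketch below enumerates and then filters; the function names are hypothetical, and the "dual_rail" filter assumes one photon per mode pair (2k, 2k+1).

```python
def fock_states_desc(n: int, m: int):
    """Yield every m-mode Fock state with n photons in descending lexicographic order."""
    if m == 1:
        yield (n,)
        return
    # Larger occupation of the first mode comes first; recurse on the remaining modes.
    for first in range(n, -1, -1):
        for rest in fock_states_desc(n - first, m - 1):
            yield (first,) + rest


def enumerate_states(scheme: str, n: int, m: int) -> list[tuple[int, ...]]:
    states = list(fock_states_desc(n, m))
    if scheme == "fock":
        return states
    if scheme == "unbunched":
        return [s for s in states if max(s) <= 1]
    if scheme == "dual_rail":  # assumption: exactly one photon in each pair (2k, 2k+1)
        return [s for s in states
                if m == 2 * n and all(s[2 * k] + s[2 * k + 1] == 1 for k in range(n))]
    raise ValueError(f"unsupported scheme: {scheme!r}")


print(enumerate_states("fock", 2, 2))       # [(2, 0), (1, 1), (0, 2)]
print(enumerate_states("unbunched", 2, 3))  # [(1, 1, 0), (1, 0, 1), (0, 1, 1)]
```

Filtering the full Fock space is simple but wasteful for large m; a production implementation would more likely generate only admissible states directly.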
- fock_to_index(counts)
Map a Fock state to its index under the configured scheme.
- Return type:
int
Parameters
- counts : Iterable[int]
Photon counts per mode.
Returns
- int
Rank of the Fock state in descending lexicographic order.
Raises
- ValueError
If counts violates the scheme-specific constraints.
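For the "fock" scheme, a rank can be computed without enumerating the space, which is the usual combinadic approach. Every state that shares the prefix counts[:i] but holds more photons in mode i precedes `counts`, and the number of ways to place p photons in k modes is C(p + k - 1, k - 1). This is an independent sketch (`fock_index_desc` is hypothetical); merlin's internal algorithm may differ, and the other schemes would need their own counts.

```python
from math import comb


def fock_index_desc(counts) -> int:
    """Rank of a Fock state among all states with the same photon number and
    mode count, in descending lexicographic order ("fock" scheme only)."""
    counts = tuple(counts)
    if any(c < 0 for c in counts):
        raise ValueError("photon counts must be non-negative")
    m = len(counts)
    remaining = sum(counts)
    rank = 0
    for i, c in enumerate(counts[:-1]):
        tail_modes = m - i - 1
        # States with the same prefix but v > c photons in mode i come first;
        # their other remaining - v photons fill the tail_modes later modes.
        for v in range(c + 1, remaining + 1):
            rank += comb(remaining - v + tail_modes - 1, tail_modes - 1)
        remaining -= c
    return rank


print([fock_index_desc(s) for s in [(2, 0, 0), (1, 1, 0), (0, 0, 2)]])  # [0, 1, 5]
```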
- class merlin.core.process.ComputationProcess(circuit, input_state, trainable_parameters, input_parameters, n_photons=None, reservoir_mode=False, dtype=torch.float32, device=None, computation_space=None, no_bunching=None, output_map_func=None)
Bases: AbstractComputationProcess
Handles quantum circuit computation and state evolution.
- compute_ebs_simultaneously(parameters, simultaneous_processes=1)
Evaluate a single circuit parametrisation against all superposed input states by chunking them into groups and delegating the heavy work to the TorchScript-enabled batch kernel.
The method converts the trainable parameters into a unitary matrix, normalises the input state (if it is not already normalised), filters out components with zero amplitude, and then queries the simulation graph for batches of Fock states. Each batch feeds SLOSComputeGraph.compute_batch(), producing a tensor that contains the amplitudes of all reachable output states for the selected input components. The partial results are accumulated into a preallocated tensor and finally weighted by the complex coefficients of self.input_state to produce the global output amplitudes.
- Return type:
torch.Tensor
- Args:
- parameters (list[torch.Tensor]): Differentiable parameters that encode the photonic circuit. They are forwarded to self.converter to build the unitary matrix used during the simulation.
- simultaneous_processes (int): Maximum number of non-zero input components that are propagated in a single call to compute_batch. Tuning this value trades memory consumption against wall-clock time on GPU.
- Returns:
torch.Tensor: The superposed output amplitudes with shape [batch_size, num_output_states], where batch_size is the number of independent input batches and num_output_states is the size of self.simulation_graph.mapped_keys.
- Raises:
TypeError: If self.input_state is not a torch.Tensor. The simulation graph expects tensor inputs, so other sequence types (NumPy arrays, lists, etc.) cannot be used here.
- Notes:
self.input_state is normalised in place to avoid an extra allocation.
Zero-amplitude components are skipped to minimise the number of calls to compute_batch.
The method is device-agnostic: tensors remain on the device they already occupy, so callers should ensure that parameters and self.input_state live on the same device.
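The chunk-and-accumulate pattern can be checked on the simplest case, a single photon, where the output amplitudes for basis input |e_j⟩ are just column j of the unitary, so the full superposition equals U @ c. The sketch below is an illustration under that assumption: `single_photon_amplitudes` is a hypothetical stand-in for compute_batch, the batch dimension is omitted, and unlike the method above it does not normalise the input in place.

```python
import torch


def single_photon_amplitudes(unitary, input_modes):
    """Stand-in for a batch kernel: amplitudes for inputs |e_j>, one column per mode j."""
    return unitary[:, input_modes]


def superpose_chunked(unitary, input_state, simultaneous_processes=1):
    """Accumulate sum_j c_j * A(|e_j>) over non-zero components, in chunks."""
    if not isinstance(input_state, torch.Tensor):
        raise TypeError("input_state must be a torch.Tensor")
    state = input_state / torch.linalg.vector_norm(input_state)  # normalise (out of place)
    nonzero = torch.nonzero(state != 0).flatten()  # skip zero-amplitude components
    out = torch.zeros(unitary.shape[0], dtype=unitary.dtype, device=unitary.device)
    for start in range(0, len(nonzero), simultaneous_processes):
        idx = nonzero[start:start + simultaneous_processes]
        amps = single_photon_amplitudes(unitary, idx)  # shape (m, chunk)
        out = out + amps @ state[idx]  # weight by the complex coefficients
    return out


torch.manual_seed(0)
q, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.complex64))  # random unitary
c = torch.tensor([1, 0, 1j, 0], dtype=torch.complex64)
chunked = superpose_chunked(q, c, simultaneous_processes=1)
direct = q @ (c / torch.linalg.vector_norm(c))
print(torch.allclose(chunked, direct, atol=1e-6))  # True
```

Larger chunks mean fewer kernel calls but larger intermediate tensors, which is the memory/time trade-off that simultaneous_processes controls.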
- compute_superposition_state(parameters, *, return_keys=False)
- compute_with_keys(parameters)
Compute quantum output distribution and return both keys and probabilities.
- configure_computation_space(computation_space=ComputationSpace.UNBUNCHED, *, validate_input=True)
Reconfigure the logical basis according to the desired computation space.
- Return type:
None
- class merlin.core.process.ComputationProcessFactory
Bases: object
Factory for creating computation processes.
- static create(circuit, input_state, trainable_parameters, input_parameters, reservoir_mode=False, computation_space=None, **kwargs)
Create a computation process.
- Return type:
- enum merlin.core.process.ComputationSpace(value)
Bases: str, Enum
Enumeration of supported computational subspaces.
- Member Type:
str
Valid values are as follows:
- FOCK = <ComputationSpace.FOCK: 'fock'>
- UNBUNCHED = <ComputationSpace.UNBUNCHED: 'unbunched'>
- DUAL_RAIL = <ComputationSpace.DUAL_RAIL: 'dual_rail'>
The Enum and its members also have the following methods:
- classmethod default(*, no_bunching)
Derive the default computation space from the legacy no_bunching flag.
- Return type:
ComputationSpace
- classmethod coerce(value)
Normalize user-provided values (enum instances or case-insensitive strings).
- Return type:
ComputationSpace
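A stand-in enum with the documented members shows how a `str`-backed `Enum` can support both class methods. `ComputationSpaceSketch` is hypothetical, and the mapping in `default()` (no_bunching=True to UNBUNCHED, otherwise FOCK) is an assumption inferred from the member names, not something this page confirms.

```python
from enum import Enum


class ComputationSpaceSketch(str, Enum):
    """Hypothetical stand-in mirroring the documented members (not merlin's class)."""

    FOCK = "fock"
    UNBUNCHED = "unbunched"
    DUAL_RAIL = "dual_rail"

    @classmethod
    def default(cls, *, no_bunching):
        # Assumption: the legacy no_bunching=True flag corresponds to UNBUNCHED.
        return cls.UNBUNCHED if no_bunching else cls.FOCK

    @classmethod
    def coerce(cls, value):
        """Accept an enum member or a case-insensitive string."""
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            try:
                return cls(value.lower())
            except ValueError:
                pass
        raise ValueError(f"unknown computation space: {value!r}")


print(ComputationSpaceSketch.coerce("Dual_Rail") is ComputationSpaceSketch.DUAL_RAIL)  # True
print(ComputationSpaceSketch.FOCK == "fock")  # True, thanks to the str mixin
```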
- merlin.core.process.build_slos_distribution_computegraph(m, n_photons, output_map_func=None, computation_space=None, no_bunching=None, keep_keys=True, device=None, dtype=torch.float32, index_photons=None)
Construct a reusable SLOS computation graph.
- Return type:
SLOSComputeGraph
Parameters
- m : int
Number of modes in the circuit.
- n_photons : int
Total number of photons injected in the circuit.
- output_map_func : callable, optional
Mapping applied to each output Fock state, allowing post-processing.
- computation_space : ComputationSpace, optional
Computational subspace; see ComputationSpace.
- keep_keys : bool, optional
Whether to keep the list of mapped Fock states.
- device : torch.device, optional
Device on which tensors should be allocated.
- dtype : torch.dtype, optional
Real dtype controlling numerical precision.
- index_photons : list[tuple[int, ...]], optional
Bounds for each photon placement.
Returns
- SLOSComputeGraph
Pre-built computation graph ready for repeated evaluations.
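SLOS builds output amplitudes by injecting photons one at a time and reusing the partial states. The dictionary-based sketch below follows that idea for a single input Fock state s. It is not merlin's SLOSComputeGraph (which prebuilds the graph for repeated evaluation on tensors), and the function name and list-of-rows unitary format are assumptions. The final rescaling converts the accumulated sum over photon-to-mode assignments, per(U_ts) / ∏ t_k!, into the amplitude per(U_ts) / sqrt(∏ s_j! ∏ t_k!).

```python
from collections import defaultdict
from math import factorial, prod, sqrt


def slos_amplitudes(unitary, input_state):
    """Return {output_state: amplitude} for one Fock input, adding photons one at a time.

    unitary is a list of rows of complex numbers, unitary[k][j] = U[k, j].
    """
    m = len(unitary)
    partial = {tuple([0] * m): 1 + 0j}  # vacuum before any photon is injected
    for j, count in enumerate(input_state):
        for _ in range(count):
            nxt = defaultdict(complex)
            # Send one photon from input mode j to every output mode k.
            for state, amp in partial.items():
                for k in range(m):
                    grown = list(state)
                    grown[k] += 1
                    nxt[tuple(grown)] += amp * unitary[k][j]
            partial = nxt
    # Rescale the assignment sum to the normalised amplitude.
    s_fact = prod(factorial(c) for c in input_state)
    return {t: a * sqrt(prod(factorial(c) for c in t) / s_fact) for t, a in partial.items()}


r = 1 / sqrt(2)
bs = [[r, r], [r, -r]]  # real 50:50 beam splitter
probs = {t: abs(a) ** 2 for t, a in slos_amplitudes(bs, (1, 1)).items()}
print({t: round(p, 6) for t, p in probs.items()})  # {(2, 0): 0.5, (1, 1): 0.0, (0, 2): 0.5}
```

The vanishing (1, 1) probability is the Hong-Ou-Mandel effect, a quick sanity check that the interference between the two assignments is handled correctly.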
- merlin.core.process.overload(func)
Decorator for overloaded functions/methods.
In a stub file, place two or more stub definitions for the same function in a row, each decorated with @overload.
For example:
    @overload
    def utf8(value: None) -> None: ...
    @overload
    def utf8(value: bytes) -> bytes: ...
    @overload
    def utf8(value: str) -> bytes: ...
In a non-stub file (i.e. a regular .py file), do the same but follow it with an implementation. The implementation should not be decorated with @overload:
    @overload
    def utf8(value: None) -> None: ...
    @overload
    def utf8(value: bytes) -> bytes: ...
    @overload
    def utf8(value: str) -> bytes: ...
    def utf8(value):
        ...  # implementation goes here
The overloads for a function can be retrieved at runtime using the get_overloads() function.