merlin.core.state_vector

StateVector

class merlin.core.state_vector.StateVector(tensor, n_modes, n_photons, _normalized=False)

Bases: object

Amplitude tensor bundled with its Fock metadata.

Keeps n_modes, n_photons, and the combinadics basis ordering alongside the underlying PyTorch tensor, which may be dense or sparse.

Parameters

tensor:

Dense or sparse amplitude tensor; leading dimensions (if any) are treated as batch axes.

n_modes:

Number of modes in the Fock space.

n_photons:

Total photon number represented by the state.

_normalized:

Internal flag tracking whether the stored tensor is normalized.

Notes

StateVector is a thin wrapper over a torch.Tensor. Only shape, device, dtype, and requires_grad are delegated automatically; the helpers to, clone, detach, and requires_grad_ mirror common tensor workflows while preserving the Fock metadata. Layout-changing operations (e.g., reshape/view) are intentionally not exposed: apply them to tensor explicitly if needed and rebuild the state via from_tensor.

tensor: Tensor
n_modes: int
n_photons: int
property is_normalized: bool

Whether the stored tensor is currently flagged as normalized.

property basis: Combinadics

Lazy combinadics basis for (n_modes, n_photons) in Fock ordering.

property is_sparse: bool

Return True if the underlying tensor uses a sparse layout.

property basis_size: int

Return the number of basis states for (n_modes, n_photons).

to(*args, **kwargs)

Return a new StateVector with the tensor moved/cast via torch.Tensor.to.

Return type:

StateVector

clone()

Return a cloned StateVector with identical metadata and normalization flag.

Return type:

StateVector

detach()

Return a detached StateVector that shares data with the original but does not track gradients.

Return type:

StateVector

requires_grad_(requires_grad=True)

Set requires_grad on the underlying tensor and return self.

Return type:

StateVector

memory_bytes()

Approximate memory footprint (bytes) of the underlying tensor data.

Return type:

int
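To make the footprint concrete, the sketch below estimates storage under an assumed accounting: a dense tensor stores every element, while a PyTorch COO sparse tensor stores its nonzero values plus an int64 index matrix of shape (ndim, nnz). The helpers dense_bytes and coo_bytes are hypothetical, and memory_bytes() itself may count storage differently.

```python
# Hypothetical storage estimate; assumes complex64 amplitudes (8 bytes each)
# and PyTorch COO layout (values plus an int64 index matrix of shape (ndim, nnz)).
COMPLEX64_BYTES = 8  # two float32 components
INT64_BYTES = 8


def dense_bytes(numel: int, element_size: int = COMPLEX64_BYTES) -> int:
    """Dense tensors store every element."""
    return numel * element_size


def coo_bytes(nnz: int, ndim: int, element_size: int = COMPLEX64_BYTES) -> int:
    """COO tensors store nnz values plus ndim * nnz int64 indices."""
    return nnz * element_size + ndim * nnz * INT64_BYTES


# One-hot state over the 10-element basis for (4 modes, 2 photons)
print(dense_bytes(10))            # 80
print(coo_bytes(nnz=1, ndim=1))   # 16
```

Under these assumptions the dense cost scales with the basis size while the sparse cost scales with the number of nonzero amplitudes, which is why a sparse layout pays off for one-hot states on large bases.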

to_perceval()

Convert to pcvl.StateVector.

Return type:

pcvl.StateVector | list[pcvl.StateVector]

Args:

None

Returns:

pcvl.StateVector | list[pcvl.StateVector]: A Perceval state for 1D tensors, or a list for batched tensors, with amplitudes preserved (no extra renormalization).

classmethod from_perceval(state_vector, *, dtype=None, device=None, sparse=None)

Build from a pcvl.StateVector.

Return type:

StateVector

Args:

state_vector: Perceval state to wrap.
dtype: Optional target dtype.
device: Optional target device.
sparse: Force a sparse (True) or dense (False) layout; if None, use a density heuristic (sparse when at most 30% of amplitudes are nonzero).

Returns:

StateVector: Merlin wrapper with metadata and preserved amplitudes.

Raises:

ValueError: If the Perceval state is empty or has inconsistent photon/mode counts.
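The sparse=None behaviour can be read as a simple density test. Below is a minimal sketch, assuming "density" means the fraction of nonzero amplitudes and that the 30% bound is inclusive, as the "(<=30%)" note suggests; choose_sparse is a hypothetical helper, not part of Merlin.

```python
def choose_sparse(amplitudes, sparse=None, threshold=0.3):
    """Hypothetical layout choice: True means sparse, False means dense."""
    if sparse is not None:
        return bool(sparse)  # an explicit request always wins
    if len(amplitudes) == 0:
        raise ValueError("empty state")  # mirrors the documented ValueError
    nonzero = sum(1 for a in amplitudes if a != 0)
    return nonzero / len(amplitudes) <= threshold


print(choose_sparse([1, 1] + [0] * 8))     # True: density 0.2
print(choose_sparse([1] * 4 + [0] * 6))    # False: density 0.4
```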

classmethod from_basic_state(state, *, dtype=None, device=None, sparse=True)

Create a one-hot state from a Fock occupation list/BasicState.

Return type:

StateVector

Args:

state: Occupation numbers per mode (list or BasicState).
dtype: Optional target dtype.
device: Optional target device.
sparse: Build a sparse layout when True (the default).

Returns:

StateVector: One-hot state.
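A one-hot state has a single unit amplitude at the basis index of the given occupation, with n_modes = len(state) and n_photons = sum(state). The pure-Python sketch below models this, assuming the basis follows the descending lexicographic order shown in the basis example on this page; fock_basis and one_hot are hypothetical helpers, not Merlin functions.

```python
def fock_basis(n_modes, n_photons):
    """Yield occupation tuples in descending lexicographic order (assumed)."""
    if n_modes == 1:
        yield (n_photons,)
        return
    for first in range(n_photons, -1, -1):
        for rest in fock_basis(n_modes - 1, n_photons - first):
            yield (first,) + rest


def one_hot(state, sparse=True):
    """Hypothetical one-hot builder returning (amplitudes, n_modes, n_photons)."""
    n_modes, n_photons = len(state), sum(state)  # metadata inferred from occupations
    basis = list(fock_basis(n_modes, n_photons))
    idx = basis.index(tuple(state))
    if sparse:
        return {idx: 1 + 0j}, n_modes, n_photons  # store only the nonzero entry
    dense = [0j] * len(basis)
    dense[idx] = 1 + 0j
    return dense, n_modes, n_photons


amps, n_modes, n_photons = one_hot([1, 0, 1, 0])
print(amps, n_modes, n_photons)  # {2: (1+0j)} 4 2
```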

classmethod from_tensor(tensor, *, n_modes, n_photons, dtype=None, device=None)

Wrap an existing tensor with explicit metadata.

Return type:

StateVector

Args:

tensor: Dense or sparse amplitude tensor.
n_modes: Number of modes.
n_photons: Total photons.
dtype: Optional target dtype.
device: Optional target device.

Returns:

StateVector: Wrapped tensor.

Raises:

ValueError: If the last dimension does not match the basis size.

tensor_product(other, *, sparse=None)

Tensor product of two states with metadata propagation.

If either operand is dense, the result is dense unless sparse overrides it. A fast path handles one-hot operands. The resulting state is normalized before it is returned.

Return type:

StateVector

Args:

other: Another StateVector, or a BasicState/occupation list.
sparse: Override the sparsity of the result; by default the result is dense if either input is dense.

Returns:

StateVector: Combined state with summed modes/photons (normalized).

Raises:

ValueError: If either tensor is not 1D (batched states are not supported).
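The metadata propagation can be illustrated with a small pure-Python model in which a state is a dict from occupation tuples to amplitudes. The sketch assumes the combined basis state is the concatenation of the two occupation tuples (self's modes first), so n_modes and n_photons add up; amplitudes multiply, and the result is renormalized as documented. tensor_product here is a hypothetical stand-in, not Merlin's implementation.

```python
import math


def tensor_product(left, right):
    """Hypothetical tensor product of {occupation tuple: amplitude} states."""
    out = {}
    for occ_l, amp_l in left.items():
        for occ_r, amp_r in right.items():
            # Assumed convention: left's modes come first in the combined state.
            out[occ_l + occ_r] = amp_l * amp_r
    norm = math.sqrt(sum(abs(a) ** 2 for a in out.values()))
    return {occ: a / norm for occ, a in out.items()}  # normalized, as documented


combined = tensor_product({(1, 0): 1.0}, {(0, 1): 1.0})
print(combined)  # {(1, 0, 0, 1): 1.0}
occ = next(iter(combined))
print(len(occ), sum(occ))  # 4 2: modes and photons add up
```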

index(state)

Return basis index for the given Fock state.

Return type:

Optional[int]

Args:

state: Occupation list or BasicState.

Returns:

int | None: Basis index, or None if the state is not present (for a sparse tensor, this includes states with no stored nonzero amplitude).

to_dense()

Return a dense, normalized tensor view of the amplitudes.

Return type:

Tensor

normalize()

Normalize this state in-place and return self.

Return type:

StateVector

normalized_str()

Human-friendly string of the normalized state (forces normalization for display).

Return type:

str

Notes and Examples

Constructors

from_basic_state — one-hot Fock state (sparse by default):

from merlin.core.state_vector import StateVector

sv = StateVector.from_basic_state([1, 0, 1, 0])
assert sv.is_sparse
assert sv.n_modes == 4 and sv.n_photons == 2

# Dense variant
sv_dense = StateVector.from_basic_state([1, 0, 1, 0], sparse=False)
assert not sv_dense.is_sparse

from_tensor — wrap a real or complex tensor with Fock metadata. Real data is auto-promoted to complex. The last dimension must match the basis size \(\binom{n\_modes + n\_photons - 1}{n\_photons}\):

import torch
from merlin.core.state_vector import StateVector

# Single sample (1-D)
features = torch.randn(10)
sv = StateVector.from_tensor(features, n_modes=4, n_photons=2)

# Batched (2-D) — leading dimensions are batch axes
batch = torch.randn(32, 10)
sv_batch = StateVector.from_tensor(batch, n_modes=4, n_photons=2)
assert sv_batch.shape == (32, 10)
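The basis size is a stars-and-bars count: the number of ways to place n_photons indistinguishable photons into n_modes modes. The short check below is plain Python, independent of Merlin; it computes the count with math.comb and cross-checks it by brute-force enumeration, which is a convenient way to pick the last dimension before calling from_tensor.

```python
from itertools import product
from math import comb


def basis_size(n_modes, n_photons):
    """Stars and bars: ways to put n_photons identical photons in n_modes modes."""
    return comb(n_modes + n_photons - 1, n_photons)


def brute_force_size(n_modes, n_photons):
    """Count occupation tuples directly (only practical for small cases)."""
    return sum(
        1
        for occ in product(range(n_photons + 1), repeat=n_modes)
        if sum(occ) == n_photons
    )


print(basis_size(4, 2))        # 10, so from_tensor expects a last dimension of 10
print(brute_force_size(4, 2))  # 10
```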

from_perceval — convert from a Perceval StateVector:

import perceval as pcvl
from merlin.core.state_vector import StateVector

pv = pcvl.StateVector(pcvl.BasicState([1, 0]))
sv = StateVector.from_perceval(pv)

# Round-trip
pv_back = sv.to_perceval()

Properties and metadata

n_modes and n_photons are set at construction and immutable:

sv = StateVector.from_basic_state([1, 0, 1, 0])
sv.n_modes    # 4
sv.n_photons  # 2
sv.n_modes = 5  # raises AttributeError
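One common Python pattern that yields this behaviour is a read-only property: exposing a value through @property without a setter makes assignment raise AttributeError. The class below is a hypothetical illustration of that pattern; it is not taken from Merlin's source.

```python
class FockMetadata:
    """Hypothetical holder exposing n_modes/n_photons as read-only properties."""

    def __init__(self, n_modes: int, n_photons: int):
        self._n_modes = n_modes
        self._n_photons = n_photons

    @property
    def n_modes(self) -> int:
        return self._n_modes

    @property
    def n_photons(self) -> int:
        return self._n_photons


def rejects_assignment(obj, name, value) -> bool:
    """Return True if setting obj.name raises AttributeError."""
    try:
        setattr(obj, name, value)
    except AttributeError:
        return True
    return False


meta = FockMetadata(4, 2)
print(rejects_assignment(meta, "n_modes", 5))  # True: no setter is defined
print(meta.n_modes)                            # 4
```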

shape, device, dtype, and requires_grad are delegated to the underlying tensor:

sv.shape          # torch.Size([10])  — basis_size for (4, 2)
sv.device         # device(type='cpu')
sv.dtype          # torch.complex64

basis returns the combinadics Fock ordering for (n_modes, n_photons). basis_size is equivalent to len(basis):

sv.basis_size     # 10 for (4 modes, 2 photons)
list(sv.basis)[:3]  # [(2,0,0,0), (1,1,0,0), (1,0,1,0)]
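The three entries shown are consistent with descending lexicographic order of the occupation tuples (most photons in the first mode first). Assuming that order holds for the whole basis, a recursive generator reproduces it; fock_basis is a hypothetical helper, not Merlin's Combinadics class.

```python
def fock_basis(n_modes, n_photons):
    """Yield occupation tuples, most photons in the first mode first."""
    if n_modes == 1:
        yield (n_photons,)  # the last mode takes all remaining photons
        return
    for first in range(n_photons, -1, -1):
        for rest in fock_basis(n_modes - 1, n_photons - first):
            yield (first,) + rest


basis = list(fock_basis(4, 2))
print(basis[:3])   # [(2, 0, 0, 0), (1, 1, 0, 0), (1, 0, 1, 0)]
print(len(basis))  # 10, i.e. basis_size for (4 modes, 2 photons)
```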

Amplitude lookup

Use bracket syntax with an occupation list or pcvl.BasicState:

import perceval as pcvl

sv = StateVector.from_basic_state([1, 0, 1, 0], sparse=False)
amp = sv[[1, 0, 1, 0]]                    # complex scalar
amp = sv[pcvl.BasicState([1, 0, 1, 0])]   # equivalent

For batched states, the returned tensor matches the batch shape:

import torch

batch = torch.randn(8, 10, dtype=torch.complex64)
sv = StateVector.from_tensor(batch, n_modes=4, n_photons=2)
amps = sv[[1, 0, 1, 0]]   # shape: (8,)

index(state) returns the integer basis index (or None if the state is absent in a sparse tensor):

sv.index([1, 0, 1, 0])  # 2, the position of (1,0,1,0) in the basis ordering above
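Under the same descending-lexicographic assumption, the index can be computed without enumerating the basis: for each mode, count the states that agree on earlier modes but place more photons in the current one. This is the combinadic ranking idea; fock_index below is a hypothetical sketch and, unlike index(), never returns None (it does not model absent entries in a sparse tensor).

```python
from math import comb


def count_states(n_photons, n_modes):
    """Ways to place n_photons into n_modes modes (n_modes >= 1)."""
    return comb(n_photons + n_modes - 1, n_modes - 1)


def fock_index(state):
    """Hypothetical rank of `state` in descending lexicographic order."""
    n_modes = len(state)
    remaining = sum(state)
    rank = 0
    for i, occ in enumerate(state[:-1]):
        modes_after = n_modes - i - 1
        # States agreeing on modes 0..i-1 but with more photons in mode i come first.
        for v in range(occ + 1, remaining + 1):
            rank += count_states(remaining - v, modes_after)
        remaining -= occ
    return rank


print(fock_index([1, 0, 1, 0]))  # 2
print(fock_index([0, 0, 0, 2]))  # 9, the last of the 10 basis states
```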

Superpositions and arithmetic

Addition and subtraction require matching n_modes and n_photons. Results are not automatically normalized — call .normalize() explicitly:

a = StateVector.from_basic_state([1, 0], sparse=False)
b = StateVector.from_basic_state([0, 1], sparse=False)

superposed = (a + b).normalize()       # (|1,0⟩ + |0,1⟩) / √2
diff       = (a - b).normalize()       # (|1,0⟩ - |0,1⟩) / √2
scaled     = 0.5 * a                   # unnormalized until .normalize()

normalize() acts in-place and returns self.
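The arithmetic is ordinary vector addition followed by rescaling by the Euclidean norm \(\lVert\psi\rVert = \sqrt{\sum_k |a_k|^2}\). The pure-Python sketch below mirrors the superposition example using dicts from occupation tuples to amplitudes; add and normalize are hypothetical helpers, and unlike StateVector.normalize() this normalize returns a new dict instead of acting in place.

```python
import math


def add(a, b, sign=1):
    """Superpose two {occupation: amplitude} states (metadata assumed to match)."""
    out = dict(a)
    for occ, amp in b.items():
        out[occ] = out.get(occ, 0) + sign * amp
    return out


def normalize(state):
    """Return a copy scaled so that the squared magnitudes sum to 1."""
    norm = math.sqrt(sum(abs(x) ** 2 for x in state.values()))
    if norm == 0:
        raise ValueError("cannot normalize a zero state")
    return {occ: x / norm for occ, x in state.items()}


a = {(1, 0): 1.0}
b = {(0, 1): 1.0}
superposed = normalize(add(a, b))      # (|1,0> + |0,1>) / sqrt(2)
diff = normalize(add(a, b, sign=-1))   # (|1,0> - |0,1>) / sqrt(2)
print(superposed)
```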

Tensor product (tensor_product or @) combines two sub-systems:

left  = StateVector.from_basic_state([1, 0], sparse=False)
right = StateVector.from_basic_state([0, 1], sparse=True)
combined = left.tensor_product(right)  # or: left @ right
assert combined.n_modes == 4 and combined.n_photons == 2

Dense tensor access

to_dense() returns a normalized, dense torch.Tensor. Sparse states are materialized; already-dense states are returned directly:

sv = StateVector.from_basic_state([1, 0, 1, 0])
dense = sv.to_dense()   # shape: (10,), complex, sum of |amplitudes|^2 == 1
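Materializing a sparse state amounts to scattering its stored (index, value) pairs into a zero vector of length basis_size, then normalizing. Below is a minimal pure-Python sketch of that step, assuming a COO-style list of indices and values; to_dense here is a hypothetical helper, not the method itself. Duplicate indices are summed, matching how PyTorch treats uncoalesced COO entries.

```python
import math


def to_dense(indices, values, size):
    """Hypothetical: scatter COO-style (index, value) pairs, then normalize."""
    dense = [0j] * size
    for i, v in zip(indices, values):
        dense[i] += v  # duplicate indices accumulate, as in uncoalesced COO
    norm = math.sqrt(sum(abs(x) ** 2 for x in dense))
    return [x / norm for x in dense]


# One-hot |1,0,1,0> for (4 modes, 2 photons): a single entry at index 2
dense = to_dense([2], [1 + 0j], size=10)
print(len(dense), dense[2])  # 10 (1+0j)
```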

PyTorch-like helpers

to, clone, detach, and requires_grad_ mirror the standard torch.Tensor API while preserving Fock metadata:

sv = StateVector.from_basic_state([1, 0], sparse=False)

sv_cuda = sv.to("cuda")               # needs a CUDA device; preserves n_modes/n_photons
sv_copy = sv.clone()                   # independent copy with same metadata
sv_det  = sv.detach()                  # shares data, no gradient graph
sv.requires_grad_(True)               # enable gradients in-place; returns self

QuantumLayer integration

As input_state — sets the initial photon configuration:

import merlin as ML
from merlin.core.state_vector import StateVector

layer = ML.QuantumLayer(
    builder=builder,  # assumes an existing circuit builder; not defined in this snippet
    input_state=StateVector.from_basic_state([1, 0, 1, 0]),
    n_photons=2,
    measurement_strategy=ML.MeasurementStrategy.probs(ML.ComputationSpace.FOCK),
)

As input to forward() — activates amplitude encoding with classical data:

import torch
from merlin.core.state_vector import StateVector

features = torch.randn(32, len(layer.output_keys))
sv = StateVector.from_tensor(features, n_modes=4, n_photons=2)
output = layer(sv)   # shape: (32, output_size)

As output from forward() — with MeasurementStrategy.amplitudes(ComputationSpace.FOCK) and return_object=True:

layer = ML.QuantumLayer(
    builder=builder,
    n_photons=2,
    measurement_strategy=ML.MeasurementStrategy.amplitudes(ML.ComputationSpace.FOCK),
    return_object=True,
)
sv_out = layer(sv)               # StateVector
sv_out[[1, 0, 1, 0]]            # amplitude lookup on the output

Perceval interoperability

Round-trip between Merlin and Perceval representations:

import perceval as pcvl
from merlin.core.state_vector import StateVector

# Perceval → Merlin
pcvl_sv = (
    pcvl.StateVector(pcvl.BasicState([1, 0, 1, 0]))
    + pcvl.StateVector(pcvl.BasicState([0, 1, 0, 1]))
)
sv = StateVector.from_perceval(pcvl_sv)

# Merlin → Perceval
pv_back = sv.to_perceval()