merlin.core.partial_measurement

PartialMeasurementBranch

class merlin.core.partial_measurement.PartialMeasurementBranch(outcome, probability, amplitudes)

Bases: object

Single branch of a partial measurement for a specific measured-mode outcome.

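Parameters:
  • outcome (tuple[int, ...]) – Measured outcome on the measured modes.

  • probability (torch.Tensor) – Probability of this outcome, with one entry per batch element.

  • amplitudes (merlin.core.state_vector.StateVector) – Conditional state on the unmeasured modes.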

PartialMeasurement

class merlin.core.partial_measurement.PartialMeasurement(branches, measured_modes, unmeasured_modes, grouping=None)

Bases: object

Collection of partial-measurement branches and mode metadata.

Parameters:
  • branches (tuple[PartialMeasurementBranch, ...]) – Branches ordered lexicographically by outcome.

  • measured_modes (tuple[int, ...]) – Indices of measured modes in the full system.

  • unmeasured_modes (tuple[int, ...]) – Indices of unmeasured modes in the full system.

  • grouping (Callable[[torch.Tensor], torch.Tensor] | None) – Optional callable used to group branch probabilities.

property amplitudes

Conditional amplitudes for each branch.

Type:

list[merlin.core.state_vector.StateVector]

static from_detector_transform_output(detector_output, *, grouping=None)

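Build a branch-based PartialMeasurement from the output of DetectorTransform(partial_measurement=True).

Parameters:
  • detector_output (DetectorTransformOutput) – Output produced by DetectorTransform(partial_measurement=True).

  • grouping (Callable[[torch.Tensor], torch.Tensor] | None) – Optional callable used to group branch probabilities.

Returns: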

Branch-based partial-measurement wrapper.

Return type:

PartialMeasurement

property n_measured_modes: int

Number of measured modes.

Type:

int

property n_unmeasured_modes: int

Number of unmeasured modes.

Type:

int

property outcomes

Measured outcomes for each branch.

Type:

list[tuple[int, …]]
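The derived properties above (n_measured_modes, n_unmeasured_modes, outcomes, amplitudes) follow directly from the constructor arguments. The sketch below mirrors them with hypothetical pure-Python stand-ins; BranchSketch and PartialMeasurementSketch are not part of MerLin, probabilities are plain lists instead of tensors, and the two consistency checks (one outcome entry per measured mode, disjoint mode sets) are assumptions added for illustration, not documented behavior.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class BranchSketch:
    """Hypothetical stand-in for PartialMeasurementBranch."""
    outcome: Tuple[int, ...]   # measured outcome on the measured modes
    probability: List[float]   # one probability per batch element
    amplitudes: object = None  # conditional state on the unmeasured modes (opaque here)


@dataclass(frozen=True)
class PartialMeasurementSketch:
    """Hypothetical stand-in for PartialMeasurement (grouping omitted)."""
    branches: Tuple[BranchSketch, ...]
    measured_modes: Tuple[int, ...]
    unmeasured_modes: Tuple[int, ...]

    def __post_init__(self) -> None:
        # Assumed invariants, checked here for illustration only
        if set(self.measured_modes) & set(self.unmeasured_modes):
            raise ValueError("measured and unmeasured modes overlap")
        for b in self.branches:
            if len(b.outcome) != len(self.measured_modes):
                raise ValueError(f"outcome {b.outcome} does not match the measured modes")

    @property
    def n_measured_modes(self) -> int:
        return len(self.measured_modes)

    @property
    def n_unmeasured_modes(self) -> int:
        return len(self.unmeasured_modes)

    @property
    def outcomes(self) -> List[Tuple[int, ...]]:
        # One outcome per branch, in branch order
        return [b.outcome for b in self.branches]

    @property
    def amplitudes(self) -> list:
        # One conditional state per branch, in branch order
        return [b.amplitudes for b in self.branches]


pm = PartialMeasurementSketch(
    branches=(BranchSketch((0, 1), [0.5]), BranchSketch((1, 0), [0.5])),
    measured_modes=(0, 1),
    unmeasured_modes=(2, 3),
)
print(pm.n_measured_modes, pm.n_unmeasured_modes, pm.outcomes)
```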

property probabilities: Tensor

Alias for tensor.

Type:

torch.Tensor

property probability_tensor_shape: tuple[int, int]

Return the expected (batch, n_outcomes) shape for the probability tensor.

reorder_branches()

Reorder branches lexicographically by their outcomes.

Return type:

None

set_grouping(grouping)

Set the grouping used to aggregate probabilities.

Parameters:

grouping (Callable[[torch.Tensor], torch.Tensor] | None) – Callable used to group branch probabilities.

Raises:

TypeError – If grouping is not callable.

Return type:

None
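The documented contract of set_grouping can be sketched as follows. GroupingHolder is a hypothetical class, and treating None as "clear the grouping" is an assumption based on the Callable | None parameter type; this page only states that a non-callable argument raises TypeError.

```python
from typing import Callable, List, Optional

Grouping = Callable[[List[float]], List[float]]


class GroupingHolder:
    """Hypothetical illustration of the set_grouping contract (not MerLin code)."""

    def __init__(self) -> None:
        self.grouping: Optional[Grouping] = None

    def set_grouping(self, grouping: Optional[Grouping]) -> None:
        # Assumption: None is accepted and clears any previous grouping
        if grouping is not None and not callable(grouping):
            raise TypeError(f"grouping must be callable or None, got {type(grouping).__name__}")
        self.grouping = grouping


holder = GroupingHolder()
holder.set_grouping(lambda probs: [sum(probs)])  # accepted: callable
holder.set_grouping(None)                        # accepted under the stated assumption

error_message = ""
try:
    holder.set_grouping(42)                      # rejected: an int is not callable
except TypeError as exc:
    error_message = str(exc)
print(error_message)
```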

property tensor: Tensor

Branch probabilities stacked into a single tensor. This property assumes that the branches are ordered lexicographically by outcome, so the columns of the stacked tensor follow that same order.

Returns:

Tensor of shape (batch, n_branches). If a grouping is set, the returned tensor has shape (batch, grouping_output_size).

Return type:

torch.Tensor
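The stacking described above can be sketched without PyTorch by using nested lists in place of tensors. stack_branch_probabilities is a hypothetical helper; it illustrates the (batch, n_branches) layout reported by probability_tensor_shape and the per-row application of an optional grouping. In PyTorch, stacking a list of (batch,)-shaped tensors with torch.stack(..., dim=1) produces the same layout, although this page does not say whether MerLin uses that exact call.

```python
from typing import Callable, List, Optional, Sequence


def stack_branch_probabilities(
    branch_probs: Sequence[Sequence[float]],
    grouping: Optional[Callable[[List[float]], List[float]]] = None,
) -> List[List[float]]:
    """Stack per-branch probability vectors into (batch, n_branches) rows.

    branch_probs[k][b] is the probability of branch k for batch element b.
    If a grouping is given, it is applied to each row independently.
    """
    n_branches = len(branch_probs)
    batch = len(branch_probs[0]) if n_branches else 0
    if any(len(p) != batch for p in branch_probs):
        raise ValueError("all branches must have the same batch size")
    # Column k of the result is branch k, so column order follows branch order
    rows = [[branch_probs[k][b] for k in range(n_branches)] for b in range(batch)]
    if grouping is not None:
        rows = [grouping(row) for row in rows]
    return rows


branch_probs = [[0.25, 0.5], [0.25, 0.125], [0.5, 0.375]]  # 3 branches, batch size 2
table = stack_branch_probabilities(branch_probs)
print(table)  # [[0.25, 0.25, 0.5], [0.5, 0.125, 0.375]]

# A toy grouping that merges branches 0 and 2 into one group
grouped = stack_branch_probabilities(branch_probs, grouping=lambda row: [row[0] + row[2], row[1]])
print(grouped)  # [[0.75, 0.25], [0.875, 0.125]]
```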

verify_branches_order()

Verify that branches are ordered lexicographically by their outcomes.

Return type:

None
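The ordering rule behind reorder_branches and verify_branches_order can be sketched with Python's built-in tuple comparison, which is lexicographic. Both helpers are hypothetical, and reading "lexicographic" as Python tuple ordering is an assumption. Because verify_branches_order returns None, this page does not say how it reports a violation, so the sketch simply returns a boolean.

```python
from typing import List, Sequence, Tuple

Outcome = Tuple[int, ...]


def is_lexicographically_ordered(outcomes: Sequence[Outcome]) -> bool:
    # Python compares tuples element by element; duplicates are not rejected here
    return all(a <= b for a, b in zip(outcomes, outcomes[1:]))


def reorder(
    outcomes: Sequence[Outcome], probs: Sequence[List[float]]
) -> Tuple[List[Outcome], List[List[float]]]:
    # Sort branch indices by outcome, then permute outcomes and probabilities together
    order = sorted(range(len(outcomes)), key=lambda k: outcomes[k])
    return [outcomes[k] for k in order], [probs[k] for k in order]


outcomes = [(1, 0), (0, 1), (0, 2)]
probs = [[0.5], [0.25], [0.25]]
print(is_lexicographically_ordered(outcomes))  # False

sorted_outcomes, sorted_probs = reorder(outcomes, probs)
print(sorted_outcomes)  # [(0, 1), (0, 2), (1, 0)]
print(is_lexicographically_ordered(sorted_outcomes))  # True
```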

merlin.core.partial_measurement.DetectorTransformOutput

alias of list[dict[tuple[int | None, …], list[tuple[Tensor, Tensor]]]]

Notes and Examples

Basic Structure

A PartialMeasurement represents the outcome of measuring a subset of modes in a quantum system. It consists of a collection of PartialMeasurementBranch objects, each corresponding to a specific measurement outcome on the measured modes, along with the conditional quantum state on the unmeasured modes.

from merlin.core.partial_measurement import PartialMeasurement, PartialMeasurementBranch
from merlin.core.state_vector import StateVector
import torch

# Create branches manually
outcome_1 = (1, 0)  # measurement result on measured modes 0 and 1
prob_1 = torch.tensor([0.5])  # probability for this outcome
amps_1 = StateVector.from_basic_state([1, 0], sparse=False)  # conditional state on unmeasured modes

branch_1 = PartialMeasurementBranch(outcome_1, prob_1, amps_1)
branch_2 = PartialMeasurementBranch((0, 1), torch.tensor([0.5]), StateVector.from_basic_state([0, 1], sparse=False))

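# Combine into PartialMeasurement. Branches must be ordered lexicographically
# by outcome, so the (0, 1) branch comes before the (1, 0) branch.
pm = PartialMeasurement(
    branches=(branch_2, branch_1),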
    measured_modes=(0, 1),
    unmeasured_modes=(2, 3)
)
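To see how the pieces of a branch fit together: each branch pairs the measured outcome with a conditional state on the remaining modes. For basic (Fock) states, the two can be interleaved back into a full occupation tuple. merge_occupations is a hypothetical helper, not part of MerLin, and it assumes that the measured and unmeasured modes together cover the indices 0..n-1 exactly once.

```python
from typing import Sequence, Tuple


def merge_occupations(
    outcome: Sequence[int],
    measured_modes: Sequence[int],
    conditional: Sequence[int],
    unmeasured_modes: Sequence[int],
) -> Tuple[int, ...]:
    """Place measured counts and a conditional basic state back into full-mode order."""
    all_modes = sorted(list(measured_modes) + list(unmeasured_modes))
    if all_modes != list(range(len(all_modes))):
        raise ValueError("modes must cover 0..n-1 exactly once")
    if len(outcome) != len(measured_modes) or len(conditional) != len(unmeasured_modes):
        raise ValueError("occupation lengths must match their mode lists")
    full = [0] * len(all_modes)
    for mode, count in zip(measured_modes, outcome):
        full[mode] = count
    for mode, count in zip(unmeasured_modes, conditional):
        full[mode] = count
    return tuple(full)


# The two branches from the example above, as full 4-mode basic states
print(merge_occupations((1, 0), (0, 1), (1, 0), (2, 3)))  # (1, 0, 1, 0)
print(merge_occupations((0, 1), (0, 1), (0, 1), (2, 3)))  # (0, 1, 0, 1)
```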

Accessing Measurement Results

The PartialMeasurement class provides convenient access to the measurement outcomes and their associated probabilities.

# Access properties
print(f"Measured modes: {pm.measured_modes}")
print(f"Unmeasured modes: {pm.unmeasured_modes}")
print(f"Number of branches: {len(pm.branches)}")

# Get probability distribution across all outcomes
prob_tensor = pm.tensor  # shape: (batch_size, n_branches)

# Get individual outcomes and amplitudes
for outcome, branch in zip(pm.outcomes, pm.branches):
    print(f"Outcome {outcome}: probability={branch.probability}, amplitude shape={branch.amplitudes.shape}")
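When no grouping is set and the branches are ordered, column k of pm.tensor corresponds to pm.outcomes[k], so a row-wise argmax recovers the most likely measured outcome for each batch element. The sketch below uses plain lists in place of a tensor, and most_likely_outcomes is a hypothetical helper; with PyTorch, pm.tensor.argmax(dim=1) yields the corresponding column indices.

```python
from typing import List, Sequence, Tuple


def most_likely_outcomes(
    outcomes: Sequence[Tuple[int, ...]],
    prob_rows: Sequence[Sequence[float]],
) -> List[Tuple[int, ...]]:
    """For each row of a (batch, n_branches) table, return the most probable outcome."""
    result = []
    for row in prob_rows:
        # max() returns the first maximal index, so ties resolve to the lowest column
        best = max(range(len(row)), key=lambda k: row[k])
        result.append(outcomes[best])
    return result


outcomes = [(0, 1), (1, 0)]
prob_rows = [[0.3, 0.7], [0.6, 0.4]]  # (batch=2, n_branches=2)
print(most_likely_outcomes(outcomes, prob_rows))  # [(1, 0), (0, 1)]
```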

Working with Grouped Probabilities

The PartialMeasurement supports optional grouping of probabilities, which allows you to aggregate outcomes according to a custom grouping function (e.g., for classical post-processing or symmetry-based grouping).

from merlin.utils.grouping import ModGrouping

# Define a grouping function
grouping = ModGrouping(input_size=len(pm.branches), output_size=2)  # input_size matches the number of branches

# Apply grouping to PartialMeasurement
pm.set_grouping(grouping)

# Grouped probabilities now have shape (batch_size, output_size) instead of (batch_size, n_branches)
grouped_probs = pm.probabilities
print(grouped_probs.shape)  # (batch_size, 2)
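This page does not define the rule ModGrouping applies. One plausible reading of the name, sketched here purely as an assumption, is that branch k is assigned to group k mod output_size and the probabilities within each group are summed, which preserves the total probability of each batch row. mod_grouping is a hypothetical helper; the actual behavior should be checked in merlin.utils.grouping.

```python
from typing import List, Sequence


def mod_grouping(probs: Sequence[float], output_size: int) -> List[float]:
    """Sum probabilities whose index shares the same remainder modulo output_size."""
    if output_size <= 0:
        raise ValueError("output_size must be positive")
    grouped = [0.0] * output_size
    for index, p in enumerate(probs):
        grouped[index % output_size] += p
    return grouped


# Branches 0 and 2 land in group 0; branches 1 and 3 land in group 1
print(mod_grouping([0.125, 0.5, 0.25, 0.125], output_size=2))  # [0.375, 0.625]
```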

Creating from DetectorTransform Output

PartialMeasurement objects are typically created from the output of a detector transformation in the measurement pipeline.

from merlin.core.partial_measurement import PartialMeasurement

# When DetectorTransform produces partial measurement output (partial_measurement=True),
# it returns a structure that can be converted to PartialMeasurement
detector_output = [...]  # output from DetectorTransform

pm = PartialMeasurement.from_detector_transform_output(
    detector_output,
    grouping=None
)

Batch Processing

Each branch stores one probability per batch element, so a single PartialMeasurement can hold the results of a batched circuit evaluation.

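import torch
from merlin.core.partial_measurement import PartialMeasurementBranch
from merlin.core.state_vector import StateVector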

# Batch-wise probabilities
batch_probs = torch.tensor([[0.3, 0.7], [0.6, 0.4]])  # shape: (batch_size=2, n_outcomes=2)

# When creating a branch with batched probabilities
branch = PartialMeasurementBranch(
    outcome=(1, 0),
    probability=batch_probs[:, 0],  # shape: (batch_size,)
    amplitudes=StateVector.from_basic_state([1, 0])
)
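The relationship between a batched probability table and individual branches can be sketched with plain lists: column k of a (batch, n_outcomes) table becomes the (batch,)-shaped probability of branch k, matching batch_probs[:, 0] above. The helpers are hypothetical, and the normalization check assumes that the branches enumerate every outcome, which is a property of well-formed inputs rather than a documented MerLin guarantee.

```python
from typing import List, Sequence


def columns(table: Sequence[Sequence[float]]) -> List[List[float]]:
    """Split a (batch, n_outcomes) table into one (batch,) list per outcome."""
    return [list(col) for col in zip(*table)]


def rows_normalized(table: Sequence[Sequence[float]], tol: float = 1e-9) -> bool:
    """Check that each batch element's probabilities sum to 1 within tol."""
    return all(abs(sum(row) - 1.0) <= tol for row in table)


batch_probs = [[0.3, 0.7], [0.6, 0.4]]  # (batch_size=2, n_outcomes=2)
per_branch = columns(batch_probs)
print(per_branch)                    # [[0.3, 0.6], [0.7, 0.4]]
print(rows_normalized(batch_probs))  # True
```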