merlin.utils.dtypes module

Utilities for converting between various dtype representations and torch dtypes.

merlin.utils.dtypes.to_torch_dtype(dtype_like, *, default=None)

Convert common dtype representations (strings, numpy dtypes, torch dtypes) into torch dtypes.

Parameters:
  • dtype_like (object) – Input representation to convert.

  • default (torch.dtype | None) – Fallback dtype returned when the representation is not recognized. When dtype_like is None, the fallback is torch.float32.

Returns:

Torch dtype corresponding to the requested representation.

Return type:

torch.dtype

Raises:

TypeError – If the value cannot be mapped and no default is provided.
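
The contract above can be illustrated with a small stand-in. This is only a sketch and not Merlin's implementation: the name to_torch_dtype_sketch is hypothetical, the handling of None is an assumption based on the parameter description, and numpy dtype objects are handled by duck typing on their .name attribute so the block depends only on torch.

```python
import torch


def to_torch_dtype_sketch(dtype_like, *, default=None):
    """Hypothetical stand-in mirroring the documented contract (not Merlin's code)."""
    # Torch dtypes pass through unchanged.
    if isinstance(dtype_like, torch.dtype):
        return dtype_like
    # Assumption: None resolves to `default`, or to torch.float32 if none is given.
    if dtype_like is None:
        return torch.float32 if default is None else default
    # Strings are used as-is; numpy dtype objects such as np.dtype("float32")
    # expose a `.name` attribute like "float32".
    name = dtype_like if isinstance(dtype_like, str) else getattr(dtype_like, "name", None)
    candidate = getattr(torch, name, None) if isinstance(name, str) else None
    if isinstance(candidate, torch.dtype):
        return candidate
    # Unknown representation: use the fallback if one was supplied.
    if default is not None:
        return default
    raise TypeError(f"Cannot map {dtype_like!r} to a torch dtype")


print(to_torch_dtype_sketch("float64"))        # torch.float64
print(to_torch_dtype_sketch(torch.complex64))  # torch.complex64
print(to_torch_dtype_sketch("not_a_dtype", default=torch.float16))  # torch.float16
```

Without a default, the last call would raise TypeError, which matches the Raises entry above.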

merlin.utils.dtypes.complex_dtype_for(dtype_like)

Return the matching complex dtype for the provided float or complex dtype.

Parameters:

dtype_like (object) – Representation of a torch dtype (string, numpy dtype, torch dtype, …).

Returns:

Torch complex dtype corresponding to the provided representation.

Return type:

torch.dtype

Raises:

TypeError – If the dtype cannot be mapped to a supported float/complex pair.
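
The float-to-complex correspondence is the same one PyTorch applies when it builds complex tensors from real parts. The snippet below uses only standard torch calls, not Merlin's functions, to show that correspondence; which pairs Merlin actually supports is not stated in this reference.

```python
import torch

# torch.complex builds a complex tensor from real and imaginary parts;
# the result dtype is the complex counterpart of the input float dtype.
real32 = torch.zeros(2, dtype=torch.float32)
real64 = torch.zeros(2, dtype=torch.float64)

z64 = torch.complex(real32, real32)
z128 = torch.complex(real64, real64)

print(z64.dtype)   # torch.complex64
print(z128.dtype)  # torch.complex128
```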

merlin.utils.dtypes.float_dtype_for(dtype_like)

Return the matching float dtype for the provided float or complex dtype.

Parameters:

dtype_like (object) – Representation of a torch dtype (string, numpy dtype, torch dtype, …).

Returns:

Torch float dtype corresponding to the provided representation.

Return type:

torch.dtype

Raises:

TypeError – If the dtype cannot be mapped to a supported float/complex pair.
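
A typical reason to ask for the float side is to allocate real-valued buffers at the same precision as a complex tensor. The following sketch shows that relationship with plain torch attributes rather than Merlin's function; the variable names are illustrative.

```python
import torch

z = torch.zeros(4, dtype=torch.complex128)

# The real part and the magnitude of a complex tensor use the matching float dtype.
print(z.real.dtype)   # torch.float64
print(z.abs().dtype)  # torch.float64

# Allocate a real buffer at the same precision as the complex tensor.
probs = torch.empty(z.shape, dtype=z.real.dtype)
print(probs.dtype)    # torch.float64
```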

merlin.utils.dtypes.resolve_float_complex(dtype)

Given a torch dtype on either the float or the complex side of a pair, return the full (float, complex) pair.

Parameters:

dtype (torch.dtype) – Torch float or complex dtype.

Returns:

Matching (float_dtype, complex_dtype) pair.

Return type:

tuple[torch.dtype, torch.dtype]

Raises:

TypeError – If the dtype is not one of the supported float or complex dtypes.
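
One simple way to implement such a lookup is a table of pairs searched from either side. The sketch below is hypothetical: resolve_pair_sketch and _PAIRS are illustrative names, and the two pairs listed are an assumption, since the set Merlin supports is not given here.

```python
import torch

# Assumed supported pairs; the actual set supported by Merlin is not listed here.
_PAIRS = (
    (torch.float32, torch.complex64),
    (torch.float64, torch.complex128),
)


def resolve_pair_sketch(dtype):
    """Hypothetical stand-in: return (float_dtype, complex_dtype) for either side."""
    for float_dtype, complex_dtype in _PAIRS:
        if dtype in (float_dtype, complex_dtype):
            return float_dtype, complex_dtype
    raise TypeError(f"Unsupported dtype: {dtype!r}")


pair_from_complex = resolve_pair_sketch(torch.complex64)
pair_from_float = resolve_pair_sketch(torch.float64)

print(pair_from_complex)  # (torch.float32, torch.complex64)
print(pair_from_float)    # (torch.float64, torch.complex128)
```

Returning both sides at once lets a caller allocate real parameters and complex amplitudes at a consistent precision from a single argument.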