merlin.utils.grouping module
Grouping policies.
- class merlin.utils.grouping.LexGrouping(input_size, output_size)
Bases: Module
Maps tensor to a lexical grouping of its components.
This mapper groups consecutive elements of the input tensor into equal-sized buckets and sums them to produce the output. If the input size is not evenly divisible by the output size, padding is applied.
- forward(x)
Map the input tensor to the desired output_size utilizing lexical grouping.
- Args:
x: Input tensor of shape (n_batch, input_size) or (input_size,)
- Returns:
Grouped tensor of shape (n_batch, output_size) or (output_size,)
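The lexical grouping described above can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: `lex_group` is a hypothetical helper that works on a flat list rather than a tensor, and it assumes that padding means appending zeros at the end until the length is a multiple of `output_size`. The actual padding strategy of `LexGrouping` is not stated here.

```python
import math


def lex_group(x, output_size):
    """Sum consecutive, equal-sized buckets of x (hypothetical helper)."""
    # Bucket width: smallest size that lets output_size buckets cover x.
    bucket = math.ceil(len(x) / output_size)
    # Assumption: zero-pad at the end so every bucket is full.
    padded = list(x) + [0.0] * (bucket * output_size - len(x))
    return [sum(padded[i * bucket:(i + 1) * bucket]) for i in range(output_size)]


# Evenly divisible: buckets are (1, 2), (3, 4), (5, 6).
print(lex_group([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3))  # [3.0, 7.0, 11.0]
# Not divisible: buckets are (1, 2, 3) and (4, 5, 0) after zero padding.
print(lex_group([1.0, 2.0, 3.0, 4.0, 5.0], 2))  # [6.0, 9.0]
```

Because zeros do not change a sum, end padding under this assumption only affects which elements share the last bucket.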
- class merlin.utils.grouping.ModGrouping(input_size, output_size)
Bases: Module
Maps tensor to a modulo grouping of its components.
This mapper groups elements of the input tensor based on their index modulo the output size. Elements with the same modulo value are summed together to produce the output.
- forward(x)
Map the input tensor to the desired output_size utilizing modulo grouping.
- Args:
x: Input tensor of shape (n_batch, input_size) or (input_size,)
- Returns:
Grouped tensor of shape (n_batch, output_size) or (output_size,)
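The modulo grouping described above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: `mod_group` is a hypothetical helper operating on a flat list, where element `i` is added to output slot `i % output_size`.

```python
def mod_group(x, output_size):
    """Sum elements of x whose indices share the same remainder (hypothetical helper)."""
    out = [0.0] * output_size
    for i, value in enumerate(x):
        # Elements at indices i, i + output_size, i + 2 * output_size, ... share a slot.
        out[i % output_size] += value
    return out


# Indices 0, 2, 4 go to slot 0; indices 1, 3 go to slot 1.
print(mod_group([1.0, 2.0, 3.0, 4.0, 5.0], 2))  # [9.0, 6.0]
```

Unlike lexical grouping, this scheme needs no padding: when the input size is not a multiple of the output size, some slots simply receive one fewer element.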