batchMatMul function

VARP batchMatMul(
  VARP x,
  VARP y, {
  bool adjX = false,
  bool adjY = false,
})

Multiplies slices of two variables in batches.

Multiplies all slices of variables x and y (each slice can be viewed as an element of a batch) and arranges the individual results in a single output variable with the same batch size.

Each slice can optionally be adjointed (to adjoint a matrix means to transpose and conjugate it) before multiplication by setting the adjX or adjY flag to true; both default to false. For real-valued inputs, conjugation has no effect, so the adjoint is simply the transpose.
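To make the real-valued case concrete, the sketch below materializes the adjoint of a plain Dart matrix. `adjointReal` is a hypothetical helper for illustration only; it is not part of the MNN bindings and assumes a real-valued, rectangular (non-ragged) matrix stored as nested lists.

```dart
/// Hypothetical helper (not an MNN API): adjoint of a real-valued matrix.
/// With no imaginary part to conjugate, the adjoint is just the transpose,
/// so an r x c matrix becomes a c x r matrix with m[i][j] moved to [j][i].
List<List<double>> adjointReal(List<List<double>> m) {
  if (m.isEmpty) return <List<double>>[];
  return List.generate(
    m[0].length, // rows of the result = columns of the input
    (j) => List.generate(m.length, (i) => m[i][j]),
  );
}
```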

The input variables x and y are 2-D or higher, with shapes [..., r_x, c_x] and [..., r_y, c_y].

The output variable is 2-D or higher, with shape [..., r_o, c_o], where:

  r_o = c_x if adjX else r_x
  c_o = r_y if adjY else c_y
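The shape rule can be checked with a small function. This is a hypothetical sketch, not part of the MNN bindings. It assumes x and y have the same rank and identical batch dimensions; whether MNN also broadcasts mismatched batch dimensions is not stated here, so the sketch rejects them. It also checks that the inner dimensions agree after the optional adjoints, as any matrix product requires.

```dart
/// Hypothetical helper (not an MNN API): output shape of a batched matmul.
/// Assumes equal ranks and identical batch dimensions (no broadcasting).
List<int> batchMatMulShape(List<int> xShape, List<int> yShape,
    {bool adjX = false, bool adjY = false}) {
  if (xShape.length < 2 || yShape.length < 2) {
    throw ArgumentError('x and y must be 2-D or higher');
  }
  final rx = xShape[xShape.length - 2], cx = xShape[xShape.length - 1];
  final ry = yShape[yShape.length - 2], cy = yShape[yShape.length - 1];

  // Effective rows/columns of each slice after the optional adjoint.
  final ro = adjX ? cx : rx;
  final innerX = adjX ? rx : cx;
  final innerY = adjY ? cy : ry;
  final co = adjY ? ry : cy;
  if (innerX != innerY) {
    throw ArgumentError('inner dimensions differ: $innerX vs $innerY');
  }

  // Batch dimensions must match exactly under this sketch's assumption.
  final batchX = xShape.sublist(0, xShape.length - 2);
  final batchY = yShape.sublist(0, yShape.length - 2);
  if (batchX.length != batchY.length) {
    throw ArgumentError('x and y must have the same rank');
  }
  for (var i = 0; i < batchX.length; i++) {
    if (batchX[i] != batchY[i]) {
      throw ArgumentError('batch dimension $i differs');
    }
  }
  return [...batchX, ro, co];
}
```

For example, x of shape [4, 2, 3] and y of shape [4, 3, 5] give [4, 2, 5]; passing x as [4, 3, 2] with adjX set to true gives the same result.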

It is computed as:

  output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
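As a reference for the semantics (not for performance), the sketch below computes a batched product over rank-3 nested lists indexed as [batch][row][col]. `refBatchMatMul` is a hypothetical helper, not the MNN implementation; it assumes real-valued inputs (so the adjoint reduces to a transpose), non-empty slices, and equal batch sizes. Rather than copying transposed matrices, it reads each operand through the optional adjoint.

```dart
/// Hypothetical reference implementation (not an MNN API): naive batched
/// matrix multiply on real-valued rank-3 nested lists [batch][row][col].
List<List<List<double>>> refBatchMatMul(
    List<List<List<double>>> x, List<List<List<double>>> y,
    {bool adjX = false, bool adjY = false}) {
  if (x.length != y.length) throw ArgumentError('batch sizes differ');
  final out = <List<List<double>>>[];
  for (var b = 0; b < x.length; b++) {
    final xs = x[b], ys = y[b];
    // Accessors that read through the optional adjoint (a transpose here).
    double xAt(int i, int k) => adjX ? xs[k][i] : xs[i][k];
    double yAt(int k, int j) => adjY ? ys[j][k] : ys[k][j];

    final ro = adjX ? xs[0].length : xs.length;
    final innerX = adjX ? xs.length : xs[0].length;
    final innerY = adjY ? ys[0].length : ys.length;
    final co = adjY ? ys.length : ys[0].length;
    if (innerX != innerY) throw ArgumentError('inner dimensions differ');

    // output[b][i][j] = sum over k of x'[i][k] * y'[k][j]
    out.add(List.generate(
        ro,
        (i) => List.generate(co, (j) {
              var sum = 0.0;
              for (var k = 0; k < innerX; k++) {
                sum += xAt(i, k) * yAt(k, j);
              }
              return sum;
            })));
  }
  return out;
}
```

Each batch element is multiplied independently, which is why the output keeps the batch size of the inputs.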

Arguments:

  • x: 2-D or higher with shape [..., r_x, c_x].
  • y: 2-D or higher with shape [..., r_y, c_y].

Optional:

  • adjX: If true, adjoint the slices of x. Defaults to false.
  • adjY: If true, adjoint the slices of y. Defaults to false.

Returns:

  • Output: 2-D or higher with shape [..., r_o, c_o].

Implementation

VARP batchMatMul(VARP x, VARP y, {bool adjX = false, bool adjY = false}) =>
    VARP.fromPointer(C.mnn_expr_BatchMatMul(x.ptr, y.ptr, adjX, adjY));