batchMatMul function
Multiplies slices of two variables in batches.
Multiplies all slices of variables x and y (each slice can be viewed as one element of a batch) and arranges the individual results in a single output variable with the same batch size.
Each individual slice can optionally be adjointed before multiplication (the adjoint of a matrix is its conjugate transpose) by setting the adjX or adjY flag to true; both default to false.
The input variables x and y are 2-D or higher, with shapes [..., r_x, c_x] and [..., r_y, c_y].
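The output variable is 2-D or higher, with shape [..., r_o, c_o], where:
r_o = c_x if adjX else r_x
c_o = r_y if adjY else c_y
For the product to be defined, the inner dimensions must agree: (r_x if adjX else c_x) must equal (c_y if adjY else r_y).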
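The shape rules above can be checked with a small pure-Dart sketch. The helper batchMatMulShape is hypothetical (it is not part of the MNN Dart bindings), and it assumes the leading batch dimensions of x and y must match exactly; whether MNN broadcasts batch dimensions is not stated here.

```dart
/// Hypothetical helper (not part of the MNN Dart API): computes the output
/// shape [..., r_o, c_o] of batchMatMul from the input shapes.
List<int> batchMatMulShape(List<int> xShape, List<int> yShape,
    {bool adjX = false, bool adjY = false}) {
  if (xShape.length < 2 || yShape.length < 2) {
    throw ArgumentError('x and y must be 2-D or higher');
  }
  final rX = xShape[xShape.length - 2];
  final cX = xShape[xShape.length - 1];
  final rY = yShape[yShape.length - 2];
  final cY = yShape[yShape.length - 1];

  // Output rows/columns after the optional adjoint (r_o, c_o above).
  final rO = adjX ? cX : rX;
  final cO = adjY ? rY : cY;
  // Inner dimensions that must agree for the product to exist.
  final innerX = adjX ? rX : cX;
  final innerY = adjY ? cY : rY;
  if (innerX != innerY) {
    throw ArgumentError('inner dimensions differ: $innerX vs $innerY');
  }

  // Assumption: leading batch dimensions must match exactly (no broadcasting).
  final batch = xShape.sublist(0, xShape.length - 2);
  final batchY = yShape.sublist(0, yShape.length - 2);
  var sameBatch = batch.length == batchY.length;
  for (var i = 0; sameBatch && i < batch.length; i++) {
    sameBatch = batch[i] == batchY[i];
  }
  if (!sameBatch) {
    throw ArgumentError('batch dimensions differ: $batch vs $batchY');
  }
  return [...batch, rO, cO];
}
```

For example, x of shape [4, 3, 2] with adjX set and y of shape [4, 3, 5] gives an output shape of [4, 2, 5].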
It is computed as:
output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
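The per-slice computation can be illustrated with a naive reference sketch in pure Dart. It is not MNN's implementation: the name batchMatMulRef is hypothetical, it only handles real-valued 3-D inputs (a batch of matrices stored as nested lists), and it assumes equal batch sizes. For real data the adjoint reduces to a plain transpose, since conjugation leaves real numbers unchanged.

```dart
/// Transpose of a matrix. For real-valued data the adjoint (conjugate
/// transpose) is just the transpose.
List<List<double>> _transpose(List<List<double>> m) => List.generate(
    m[0].length, (j) => List.generate(m.length, (i) => m[i][j]));

/// Ordinary matrix product a (r x k) * b (k x c).
List<List<double>> _matMul(List<List<double>> a, List<List<double>> b) {
  final k = b.length;
  if (a[0].length != k) throw ArgumentError('inner dimensions differ');
  return List.generate(a.length, (i) => List.generate(b[0].length, (j) {
        var sum = 0.0;
        for (var t = 0; t < k; t++) {
          sum += a[i][t] * b[t][j];
        }
        return sum;
      }));
}

/// Hypothetical reference: multiplies slice n of x by slice n of y,
/// applying the optional adjoint to each slice first.
List<List<List<double>>> batchMatMulRef(
    List<List<List<double>>> x, List<List<List<double>>> y,
    {bool adjX = false, bool adjY = false}) {
  if (x.length != y.length) throw ArgumentError('batch sizes differ');
  return List.generate(x.length, (n) {
    final a = adjX ? _transpose(x[n]) : x[n];
    final b = adjY ? _transpose(y[n]) : y[n];
    return _matMul(a, b);
  });
}
```

With a single slice, x = [[1, 2], [3, 4]] and y = [[5, 6], [7, 8]] give [[19, 22], [43, 50]]; setting adjX multiplies the transpose of x instead and gives [[26, 30], [38, 44]].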
Arguments:
- x: 2-D or higher with shape [..., r_x, c_x].
- y: 2-D or higher with shape [..., r_y, c_y].
Optional:
- adjX: If true, adjoint the slices of x. Defaults to false.
- adjY: If true, adjoint the slices of y. Defaults to false.
Returns:
- Output: 2-D or higher with shape [..., r_o, c_o].
Implementation
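// Thin wrapper: forwards both operand pointers and the adjoint flags to the
// native mnn_expr_BatchMatMul binding and wraps the returned pointer in a VARP.
VARP batchMatMul(VARP x, VARP y, {bool adjX = false, bool adjY = false}) =>
    VARP.fromPointer(C.mnn_expr_BatchMatMul(x.ptr, y.ptr, adjX, adjY));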