reduce_mul_op
Structs
Struct: ReduceMul
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after reducing it along the specified axes.
Args
- curr (ArrayShape): The ArrayShape that stores the result of the computation.
- args (List[ArrayShape]): The ArrayShape to reduce (args[0]) and the axes to reduce along, encoded in an ArrayShape.
Constraints:
- Each axis must be a valid axis of the ArrayShape (args[0]).
- The number of axes must not exceed the number of dimensions of the ArrayShape (args[0]).
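The two constraints on `compute_shape` can be illustrated with a small plain-Python sketch. It is not the library's Mojo implementation: `reduce_shape` is a hypothetical helper, it rejects negative axes rather than wrapping them, and whether reduced axes are kept with size 1 or dropped is not stated here, so the hypothetical `keepdims` flag covers both behaviours.

```python
from typing import List, Sequence


def reduce_shape(shape: Sequence[int], axes: Sequence[int], keepdims: bool = True) -> List[int]:
    """Shape left after reducing `shape` along `axes` (illustrative sketch)."""
    ndim = len(shape)
    # Constraint: no more axes than dimensions.
    if len(axes) > ndim:
        raise ValueError(f"{len(axes)} axes given for a {ndim}-dimensional shape")
    # Constraint: every axis must index an existing dimension (negatives rejected here).
    for ax in axes:
        if not 0 <= ax < ndim:
            raise ValueError(f"axis {ax} is invalid for a {ndim}-dimensional shape")
    reduced = set(axes)
    if keepdims:
        # Assumed behaviour: reduced axes stay in place with size 1.
        return [1 if i in reduced else d for i, d in enumerate(shape)]
    # Alternative: drop the reduced axes entirely.
    return [d for i, d in enumerate(shape) if i not in reduced]
```

For example, `reduce_shape([2, 3, 4], [1])` gives `[2, 1, 4]`, while `reduce_shape([2, 3], [2])` raises `ValueError` because axis 2 does not exist.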
__call__(mut curr: Array, args: List[Array])
Performs the forward pass of the product reduction, multiplying the elements of the input array along the specified axes.
Args
- curr (Array): The current array that stores the result (modified in place).
- args (List[Array]): A list containing the input arrays.
Computes the product of the input array's elements along the specified axes and stores the result in the current array. Initializes the current array if it is not already set up.
Note: This function assumes that the shapes and data of the args are already set up. If the current array (curr) is not initialized, it computes the output shape from the input array and the axes and sets up the data accordingly.
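A product reduction differs from a sum reduction mainly in its starting value: each output entry must begin at the multiplicative identity 1, not 0. The plain-Python sketch below illustrates the computation on a flat row-major buffer. It is not the library's Mojo implementation; `reduce_mul_ref`, the flat-buffer layout, and keeping reduced axes with size 1 are assumptions made for illustration.

```python
import itertools
from typing import List, Sequence, Tuple


def reduce_mul_ref(data: Sequence[complex], shape: Sequence[int], axes: Sequence[int]) -> Tuple[List[complex], List[int]]:
    """Multiply a row-major buffer's elements along `axes` (reduced axes kept as size 1)."""
    reduced = set(axes)
    out_shape = [1 if i in reduced else d for i, d in enumerate(shape)]
    # Row-major strides of the output buffer.
    out_strides = [1] * len(out_shape)
    for i in range(len(out_shape) - 2, -1, -1):
        out_strides[i] = out_strides[i + 1] * out_shape[i + 1]
    size = 1
    for d in out_shape:
        size *= d
    out = [1] * size  # multiplicative identity, not 0 as for a sum
    # itertools.product walks indices in row-major order, matching `flat`.
    for flat, idx in enumerate(itertools.product(*(range(d) for d in shape))):
        # Collapse every reduced axis to index 0 to find the output slot.
        o = sum((0 if i in reduced else k) * s for i, (k, s) in enumerate(zip(idx, out_strides)))
        out[o] *= data[flat]
    return out, out_shape
```

For the 2 x 2 example used in this section, `reduce_mul_ref([1, 2, 3, 4], [2, 2], [0])` returns `([3, 8], [1, 2])`.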
jvp(primals: List[Array], tangents: List[Array]) -> Array
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the product reduction.
Args
- primals (List[Array]): A list containing the primal input arrays.
- grad (Array): The gradient of the output with respect to some scalar function.
- out (Array): The output of the forward pass (unused in this function).
Returns
- List[Array]: A list containing the gradient with respect to the input.
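Implements reverse-mode automatic differentiation for the product reduction.
Note: Each input element receives the incoming gradient (broadcast over the reduced axes) multiplied by the product of all other elements in its reduction group. When no element is zero, this equals grad * out / x.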
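For a product reduction, the gradient of each input element is the incoming gradient times the product of the other elements that reduce into the same output entry. The plain-Python sketch below computes that "product of the others" with prefix and suffix products, so inputs containing zeros work without dividing by zero. `reduce_mul_vjp_ref`, the row-major flat-buffer layout, and a `grad` whose reduced axes have size 1 are assumptions for illustration; the sketch also ignores any conjugation convention the library may apply to complex gradients.

```python
import itertools
from typing import Dict, List, Sequence


def reduce_mul_vjp_ref(data: Sequence[float], shape: Sequence[int], axes: Sequence[int], grad: Sequence[float]) -> List[float]:
    """Input gradient of a product reduction, given the row-major output cotangent `grad`."""
    reduced = set(axes)
    out_shape = [1 if i in reduced else d for i, d in enumerate(shape)]
    out_strides = [1] * len(out_shape)
    for i in range(len(out_shape) - 2, -1, -1):
        out_strides[i] = out_strides[i + 1] * out_shape[i + 1]
    # Group flat input positions by the output slot they reduce into.
    groups: Dict[int, List[int]] = {}
    for flat, idx in enumerate(itertools.product(*(range(d) for d in shape))):
        o = sum((0 if i in reduced else k) * s for i, (k, s) in enumerate(zip(idx, out_strides)))
        groups.setdefault(o, []).append(flat)
    result = [0.0] * len(data)
    for o, members in groups.items():
        n = len(members)
        prefix = [1.0] * (n + 1)  # prefix[k] = product of members[:k]
        suffix = [1.0] * (n + 1)  # suffix[k] = product of members[k:]
        for k in range(n):
            prefix[k + 1] = prefix[k] * data[members[k]]
        for k in range(n - 1, -1, -1):
            suffix[k] = suffix[k + 1] * data[members[k]]
        for k, flat in enumerate(members):
            # Product of every other element in the group, without dividing by data[flat].
            result[flat] = grad[o] * prefix[k] * suffix[k + 1]
    return result
```

With `data = [0, 5]` reduced along axis 0 and `grad = [2]`, the result is `[10.0, 0.0]`, a case where the shortcut `grad * out / x` would divide by zero.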
fwd(arg0: Array, axis: List[Int]) -> Array
Reduces the input array along the specified axes by multiplying its elements.
Args
- arg0 (Array): The input array.
- axis (List[Int]): The axes along which to reduce the array.
Returns
- Array: An array containing the product of the input array's elements along the specified axes.
Examples:
a = Array([[1, 2], [3, 4]])
result = reduce_mul(a, List(0))
print(result)
Note: This function supports:
- Automatic differentiation (forward and reverse modes).
- Complex-valued arguments.
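The forward-mode rule behind this support can be sketched for the simplest case, the full product p = x[0] * x[1] * ... * x[n-1] of a 1-D sequence. Its tangent is the sum over i of t[i] times the product of all other elements. This plain-Python `prod_jvp` is a hypothetical illustration, not the library's `jvp`. A reduction along selected axes applies the same rule independently to each group of elements that collapse into one output entry. Because the product is holomorphic, the formula holds unchanged for complex inputs.

```python
from typing import Sequence, Tuple


def prod_jvp(x: Sequence[complex], t: Sequence[complex]) -> Tuple[complex, complex]:
    """Return (p, dp) for p = prod(x), where dp = sum_i t[i] * prod_{j != i} x[j]."""
    n = len(x)
    prefix = [1] * (n + 1)  # prefix[k] = x[0] * ... * x[k-1]
    suffix = [1] * (n + 1)  # suffix[k] = x[k] * ... * x[n-1]
    for k in range(n):
        prefix[k + 1] = prefix[k] * x[k]
    for k in range(n - 1, -1, -1):
        suffix[k] = suffix[k + 1] * x[k]
    # Each tangent component is scaled by the product of the other elements.
    dp = sum(t[k] * prefix[k] * suffix[k + 1] for k in range(n))
    return prefix[n], dp
```

For instance, `prod_jvp([2, 3, 4], [1, 0, 0])` returns `(24, 12)`, since the derivative of 2 * 3 * 4 with respect to the first element is 3 * 4.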
Functions
reduce_mul
reduce_mul(arg0: Array, axis: List[Int]) -> Array
Reduces the input array along the specified axes by multiplying its elements.
Args
- arg0 (Array): The input array.
- axis (List[Int]): The axes along which to reduce the array.
Returns
- Array: An array containing the product of the input array's elements along the specified axes.
Examples:
a = Array([[1, 2], [3, 4]])
result = reduce_mul(a, List(0))
print(result)
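The values the example above should produce can be checked by hand with Python's `math.prod` (Python 3.8+). Axis 0 multiplies down each column and axis 1 multiplies across each row. This only checks the numbers. The printed format, and whether the library keeps the reduced axis with size 1, are not specified here.

```python
import math

a = [[1, 2], [3, 4]]
# Axis 0: multiply down each column -> (1*3, 2*4).
axis0 = [math.prod(col) for col in zip(*a)]
# Axis 1: multiply across each row -> (1*2, 3*4).
axis1 = [math.prod(row) for row in a]
print(axis0, axis1)  # prints "[3, 8] [2, 12]"
```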
Note: This function supports:
- Automatic differentiation (forward and reverse modes).
- Complex-valued arguments.