slp: support complex multiply and complex multiply conjugate

This adds support for complex multiply and complex multiply by conjugate to
the vect pattern detector.

Example of instructions matched:

#include <stdio.h>
#include <complex.h>

#define N 200
#define ROT
#define TYPE float
#define TYPE2 float

/* Plain complex multiply: ROT is defined empty above, so every element of c
   is simply a[i] * b[i].  */
void g (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] =  a[i] * (b[i] ROT);
    }
}

/* Complex multiply with the conjugate taken on the first operand:
   c[i] = conjf (a[i]) * b[i] (ROT is defined empty above).  */
void g_f1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] =  conjf (a[i]) * (b[i] ROT);
    }
}

/* Complex multiply with the conjugate taken on the second operand:
   c[i] = a[i] * conjf (b[i]) (ROT is defined empty above).  */
void g_s1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] =  a[i] * conjf (b[i] ROT);
    }
}

gcc/ChangeLog:

	* internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
	* optabs.def (cmul_optab, cmul_conj_optab): New.
	* doc/md.texi: Document them.
	* tree-vect-slp-patterns.c (vect_match_call_complex_mla,
	vect_normalize_conj_loc, is_eq_or_top, vect_validate_multiplication,
	vect_build_combine_node, class complex_mul_pattern,
	complex_mul_pattern::matches, complex_mul_pattern::recognize,
	complex_mul_pattern::build): New.
This commit is contained in:
Tamar Christina 2021-01-14 20:57:17 +00:00
parent 500600c784
commit e09173d84d
4 changed files with 416 additions and 0 deletions

View File

@ -6202,6 +6202,50 @@ The operation is only supported for vector modes @var{m}.
This pattern is not allowed to @code{FAIL}.
@cindex @code{cmul@var{m}3} instruction pattern
@item @samp{cmul@var{m}3}
Perform a vector multiply that is semantically the same as multiply of
complex numbers.
@smallexample
complex TYPE c[N];
complex TYPE a[N];
complex TYPE b[N];
for (int i = 0; i < N; i += 1)
@{
c[i] = a[i] * b[i];
@}
@end smallexample
In GCC lane ordering the real part of the number must be in the even lanes with
the imaginary part in the odd lanes.
The operation is only supported for vector modes @var{m}.
This pattern is not allowed to @code{FAIL}.
@cindex @code{cmul_conj@var{m}3} instruction pattern
@item @samp{cmul_conj@var{m}3}
Perform a vector multiply by conjugate that is semantically the same as a
multiply of complex numbers where the second multiply argument is conjugated.
@smallexample
complex TYPE c[N];
complex TYPE a[N];
complex TYPE b[N];
for (int i = 0; i < N; i += 1)
@{
c[i] = a[i] * conj (b[i]);
@}
@end smallexample
In GCC lane ordering the real part of the number must be in the even lanes with
the imaginary part in the odd lanes.
The operation is only supported for vector modes @var{m}.
This pattern is not allowed to @code{FAIL}.
@cindex @code{ffs@var{m}2} instruction pattern
@item @samp{ffs@var{m}2}
Store into operand 0 one plus the index of the least significant 1-bit

View File

@ -279,6 +279,8 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary)
DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary)
DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary)
DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary)
/* Complex multiply and complex multiply by conjugate.  Both are declared as
binary, ECF_CONST functions and are tied to the cmul and cmul_conj optabs. */
DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary)
DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary)
/* FP scales. */

View File

@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3")
OPTAB_D (xorsign_optab, "xorsign$F$a3")
OPTAB_D (cadd90_optab, "cadd90$a3")
OPTAB_D (cadd270_optab, "cadd270$a3")
/* Complex multiply and complex multiply by conjugate; the "3" suffix in the
pattern name follows the cadd90/cadd270 entries above. */
OPTAB_D (cmul_optab, "cmul$a3")
OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
OPTAB_D (cos_optab, "cos$a2")
OPTAB_D (cosh_optab, "cosh$a2")
OPTAB_D (exp10_optab, "exp10$a2")

View File

@ -717,6 +717,374 @@ complex_add_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
return new complex_add_pattern (node, &ops, ifn);
}
/*******************************************************************************
* complex_mul_pattern
******************************************************************************/
/* Helper function that looks for a paired operation in the CHILDth child of
   NODE.

   NODE   - the SLP node whose child is inspected; CHILD must be a valid index
	    into its children (asserted).
   CHILD  - index of the child to inspect.
   ARGS   - optional; forwarded unchanged to vect_detect_pair_op.  If the match
	    is successful it will contain the operands matched, otherwise it is
	    left unmodified.
   RES    - optional; when non-NULL it receives the child that was inspected,
	    whether or not a match was found.

   Returns the complex_operation_t reported by vect_detect_pair_op for that
   child (CMPLX_NONE when nothing matched).  */

static inline complex_operation_t
vect_match_call_complex_mla (slp_tree node, unsigned child,
			     vec<slp_tree> *args = NULL, slp_tree *res = NULL)
{
  gcc_assert (child < SLP_TREE_CHILDREN (node).length ());

  slp_tree operand = SLP_TREE_CHILDREN (node)[child];

  /* Hand the inspected child back before attempting the match so the caller
     sees it regardless of the outcome.  */
  if (res != NULL)
    *res = operand;

  return vect_detect_pair_op (operand, false, args);
}
/* Check whether either of the two trees in ARGS is a NEGATE_EXPR and, if so,
   normalize ARGS so that the negated operand sits in slot 1 with the negate
   stripped (slot 1 is replaced by the first child of the negate node).

   ARGS must contain exactly two entries (asserted).  args[0] is tested first;
   args[1] is only tested when args[0] is not a negate.

   NEG_FIRST_P is optional.  When a negate is found it is set to TRUE if the
   negate was originally in args[0] and to FALSE if it was in args[1].  When no
   negate is found it is not written at all.

   Returns TRUE if a negate was found, FALSE otherwise (ARGS untouched).

   NOTE(review): ARGS is taken by value, yet callers depend on seeing the
   reordered elements; this presumably works because vec<> handles share their
   underlying storage -- confirm.  */

static inline bool
vect_normalize_conj_loc (vec<slp_tree> args, bool *neg_first_p = NULL)
{
  gcc_assert (args.length () == 2);

  const bool first_is_neg = vect_match_expression_p (args[0], NEGATE_EXPR);
  if (!first_is_neg && !vect_match_expression_p (args[1], NEGATE_EXPR))
    return false;

  /* Move the negate node into slot 1 if it is not already there.  */
  if (first_is_neg)
    std::swap (args[0], args[1]);

  if (neg_first_p != NULL)
    *neg_first_p = first_is_neg;

  /* Look through the negate: keep only the value being negated.  */
  args[1] = SLP_TREE_CHILDREN (args[1])[0];
  return true;
}
/* Return TRUE when the permute kind recorded in PERM (its first member) is
   either exactly KIND or PERM_TOP, FALSE otherwise.  */

static inline bool
is_eq_or_top (complex_load_perm_t perm, complex_perm_kinds_t kind)
{
  const complex_perm_kinds_t actual = perm.first;
  return actual == PERM_TOP || actual == kind;
}
/* Helper function that checks whether the load permutes of the operands in
LEFT_OP and RIGHT_OP are consistent with a complex multiplication by a
conjugated value, as opposed to a multiplication by a rotated value.  Only the
load permute kinds reported by linear_loads_p are inspected here.
PERM_CACHE is passed through to linear_loads_p.
LEFT_OP and RIGHT_OP are each indexed at positions 0 and 1.
If the negation is expected to be in the first half of the tree (as required
by an FMS pattern) then NEG_FIRST is true.
CONJ_FIRST_OPERAND is an output: it is reset to false on entry and set to true
when the first operand is taken to contain the conjugate operation.  Note that
it can be left true on a path that goes on to return false.
FMS selects which check is applied to RIGHT_OP and swaps the order in which
the two entries of LEFT_OP are inspected.
Returns true when the permutes are accepted, false otherwise. */
static inline bool
vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
vec<slp_tree> left_op, vec<slp_tree> right_op,
bool neg_first, bool *conj_first_operand,
bool fms)
{
/* The presence of a negation indicates that we have either a conjugate or a
rotation. We need to distinguish which one. */
*conj_first_operand = false;
complex_perm_kinds_t kind;
/* Complex conjugates have the negation on the imaginary part of the
number where rotations affect the real component. So check if the
negation is on a dup of lane 1. */
if (fms)
{
/* Canonicalization for fms is not consistent. So have to test both
variants to be sure. This needs to be fixed in the mid-end so
this part can be simpler. */
/* Accepted combinations: right_op[0] is exactly PERM_ODDODD with right_op[1]
PERM_ODDEVEN or PERM_TOP, or right_op[0] is exactly PERM_ODDEVEN with
right_op[1] PERM_ODDODD or PERM_TOP. */
kind = linear_loads_p (perm_cache, right_op[0]).first;
if (!((kind == PERM_ODDODD
&& is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
PERM_ODDEVEN))
|| (kind == PERM_ODDEVEN
&& is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
PERM_ODDODD))))
return false;
}
else
{
/* Rejected only when right_op[1] is not PERM_ODDODD and, in addition,
right_op[0] is neither PERM_ODDEVEN nor PERM_TOP.
NOTE(review): because the two tests are joined with &&, satisfying either one
alone is enough to continue -- confirm that requiring both was not intended. */
if (linear_loads_p (perm_cache, right_op[1]).first != PERM_ODDODD
&& !is_eq_or_top (linear_loads_p (perm_cache, right_op[0]),
PERM_ODDEVEN))
return false;
}
/* Deal with differences in indexes. */
int index1 = fms ? 1 : 0;
int index2 = fms ? 0 : 1;
/* Check whether the conjugate is on the first or the second operand. The
order of the node with the conjugate value determines this, and the dup
node must be one of lane 0 of the same DR as the neg node. */
kind = linear_loads_p (perm_cache, left_op[index1]).first;
if (kind == PERM_TOP)
{
/* Accept immediately if the other entry is PERM_EVENODD.  Otherwise control
falls through to the final test with KIND still PERM_TOP, which is rejected
there as long as PERM_TOP and PERM_EVENEVEN are distinct values. */
if (linear_loads_p (perm_cache, left_op[index2]).first == PERM_EVENODD)
return true;
}
else if (kind == PERM_EVENODD)
{
/* KIND is overwritten with the permute of the other entry, so the final test
below is applied to left_op[index2] on this path. */
if ((kind = linear_loads_p (perm_cache, left_op[index2]).first) == PERM_EVENODD)
return false;
}
else if (!neg_first)
/* The flag is set before the final test on KIND, which still describes
left_op[index1] on this path. */
*conj_first_operand = true;
else
return false;
/* Whatever KIND holds at this point must be PERM_EVENEVEN. */
if (kind != PERM_EVENEVEN)
return false;
return true;
}
/* Helper function to help distinguish between a conjugate and a rotation in a
   complex multiplication.  The operations have similar shapes but the order of
   the load permutes are different.

   Returns TRUE when at least one of OP[0] and OP[1] has a load permute that is
   either PERMKIND or PERM_TOP, and FALSE when neither does.  OP[0] is the more
   common case and is tested first; OP[1] is only examined when OP[0] fails.
   PERM_CACHE is passed through to linear_loads_p.  */

static inline bool
vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
			     vec<slp_tree> op, complex_perm_kinds_t permKind)
{
  return is_eq_or_top (linear_loads_p (perm_cache, op[0]), permKind)
	 || is_eq_or_top (linear_loads_p (perm_cache, op[1]), permKind);
}
/* Combine two nodes, one containing only even and one only odd lanes, into a
   single new VEC_PERM_EXPR node that holds the lanes in even/odd order by
   means of a lane permute.

   The lanes in EVEN and ODD are duplicated 2 times inside the vectors.
   So for a lanes = 4 EVEN contains {EVEN1, EVEN1, EVEN2, EVEN2}.

   The lane count, tree REPRESENTATIVE and vectype of the new node are taken
   from REP; the vectype must be the same between all three nodes.

   The reference counts of EVEN and ODD are each incremented, and the returned
   node starts with a reference count of 1.  */

static slp_tree
vect_build_combine_node (slp_tree even, slp_tree odd, slp_tree rep)
{
  const unsigned lanes = SLP_TREE_LANES (rep);

  /* Lane 2k comes from child 0 (EVEN), lane 2k+1 from child 1 (ODD).  */
  vec<std::pair<unsigned, unsigned> > lane_perm;
  lane_perm.create (lanes);
  for (unsigned lane = 0; lane < lanes; lane += 2)
    {
      lane_perm.quick_push (std::make_pair (0, lane));
      lane_perm.quick_push (std::make_pair (1, lane + 1));
    }

  slp_tree combined = vect_create_new_slp_node (2, SLP_TREE_CODE (even));
  SLP_TREE_CODE (combined) = VEC_PERM_EXPR;
  SLP_TREE_LANE_PERMUTATION (combined) = lane_perm;
  SLP_TREE_REF_COUNT (combined) = 1;
  SLP_TREE_LANES (combined) = lanes;

  /* NOTE(review): if vect_create_new_slp_node already allocates the children
     vector for the requested operand count, this create is redundant --
     confirm.  */
  SLP_TREE_CHILDREN (combined).create (2);
  SLP_TREE_CHILDREN (combined).quick_push (even);
  SLP_TREE_CHILDREN (combined).quick_push (odd);
  SLP_TREE_REF_COUNT (even)++;
  SLP_TREE_REF_COUNT (odd)++;

  gcc_assert (lane_perm.length () == SLP_TREE_LANES (combined));

  /* Representation is set to that of the current node as the vectorizer
     can't deal with VEC_PERMs with no representation, as would be the
     case with invariants.  */
  SLP_TREE_REPRESENTATIVE (combined) = SLP_TREE_REPRESENTATIVE (rep);
  SLP_TREE_VECTYPE (combined) = SLP_TREE_VECTYPE (rep);

  return combined;
}
/* SLP pattern for complex multiplication and complex multiplication by a
   conjugate.  Instances are created through recognize or mkInstance; the
   constructor itself is protected.  */

class complex_mul_pattern : public complex_pattern
{
  protected:
    /* Forward NODE, M_OPS and IFN to complex_pattern and record that the
       replacement takes two arguments.  */
    complex_mul_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
      : complex_pattern (node, m_ops, ifn)
    {
      m_num_args = 2;
    }

  public:
    /* Try to recognize the pattern rooted at the given node; returns a new
       pattern object or NULL.  */
    static vect_pattern*
    recognize (slp_tree_to_load_perm_map_t *, slp_tree *);

    /* Return the internal function the given node matches, or IFN_LAST.  */
    static internal_fn
    matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
	     vec<slp_tree> *);

    /* Unconditionally create a pattern object from already-matched data.  */
    static vect_pattern*
    mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
    {
      return new complex_mul_pattern (node, m_ops, ifn);
    }

    /* Rewrite the matched node to use the recognized internal function.  */
    void build (vec_info *);
};
/* Pattern matcher for trying to match a complex multiply pattern in the SLP
   tree rooted at *NODE.

   OP is the paired operation already detected on *NODE; anything other than
   MINUS_PLUS is rejected.  Both children of the root must then be detected as
   MULT_MULT.  PERM_CACHE is passed through to the load-permute queries.

   Returns IFN_COMPLEX_MUL when no negate is found among the operands of the
   second multiply, IFN_COMPLEX_MUL_CONJ when one is found and the load
   permutes validate as a conjugate, and IFN_LAST when nothing matches or the
   optab validation fails.  On a match OPS is emptied and refilled with three
   entries; on failure before that point OPS is left as it was.

   The original description of the shape matched was:

     double ax = (b[i+1] * a[i]);
     double bx = (a[i+1] * b[i]);

     c[i] = c[i] - ax;
     c[i+1] = c[i+1] + bx;

   NOTE(review): that shape reads like the accumulating variant; the code
   below requires both children of the root to be multiply pairs -- confirm
   and update the example.  */

internal_fn
complex_mul_pattern::matches (complex_operation_t op,
			      slp_tree_to_load_perm_map_t *perm_cache,
			      slp_tree *node, vec<slp_tree> *ops)
{
  if (op != MINUS_PLUS)
    return IFN_LAST;

  /* First two nodes must be a multiply.  Only the operands of the second
     one are collected.  */
  slp_tree root = *node;
  auto_vec<slp_tree> muls;
  if (vect_match_call_complex_mla (root, 0) != MULT_MULT
      || vect_match_call_complex_mla (root, 1, &muls) != MULT_MULT)
    return IFN_LAST;

  /* Now operand2+4 may lead to another expression.  */
  auto_vec<slp_tree> left_op, right_op;
  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));

  if (linear_loads_p (perm_cache, left_op[1]).first == PERM_ODDEVEN)
    return IFN_LAST;

  bool neg_first = false;
  bool conj_first_operand = false;
  internal_fn ifn = IFN_LAST;

  if (vect_normalize_conj_loc (right_op, &neg_first))
    {
      /* A negate was found: only a conjugate is acceptable here.  */
      if (!vect_validate_multiplication (perm_cache, left_op, right_op,
					 neg_first, &conj_first_operand,
					 false))
	return IFN_LAST;

      ifn = IFN_COMPLEX_MUL_CONJ;
    }
  else
    {
      /* A multiplication needs to multiply against the real pair, otherwise
	 the pattern matches that of FMS.  */
      if (!vect_validate_multiplication (perm_cache, left_op, PERM_EVENEVEN)
	  || vect_normalize_conj_loc (left_op))
	return IFN_LAST;

      ifn = IFN_COMPLEX_MUL;
    }

  if (!vect_pattern_validate_optab (ifn, *node))
    return IFN_LAST;

  ops->truncate (0);
  ops->create (3);

  /* Decide the operand order from the load permute of left_op[0].  */
  const complex_perm_kinds_t kind
    = linear_loads_p (perm_cache, left_op[0]).first;

  slp_tree first, second, third;
  if (kind == PERM_EVENODD || kind == PERM_TOP)
    {
      first = left_op[1];
      second = right_op[1];
      third = left_op[0];
    }
  else
    {
      first = left_op[0];
      second = (kind == PERM_EVENEVEN && !conj_first_operand) ? right_op[0]
							       : right_op[1];
      third = left_op[1];
    }

  ops->quick_push (first);
  ops->quick_push (second);
  ops->quick_push (third);

  return ifn;
}
/* Attempt to recognize a complex mul pattern rooted at *NODE.

   The paired operation on *NODE is detected first and handed, together with
   its operands, to complex_mul_pattern::matches.  PERM_CACHE is passed through
   to that matcher.  Returns a newly allocated complex_mul_pattern when the
   matcher reports an internal function, or NULL when it reports IFN_LAST.  */

vect_pattern*
complex_mul_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
				slp_tree *node)
{
  auto_vec<slp_tree> operands;
  complex_operation_t pair_op = vect_detect_pair_op (*node, true, &operands);

  internal_fn matched
    = complex_mul_pattern::matches (pair_op, perm_cache, node, &operands);

  if (matched != IFN_LAST)
    return new complex_mul_pattern (node, &operands, matched);

  return NULL;
}
/* Perform a replacement of the detected complex mul pattern with the new
   instruction sequences.

   The matched node is rewritten in place so that child 0 is m_ops[2] and
   child 1 is a new VEC_PERM node combining m_ops[0] and m_ops[1] (built by
   vect_build_combine_node).  The previous children are released and the node
   itself is then rewritten by complex_pattern::build.  VINFO is passed
   through to complex_pattern::build.  */

void
complex_mul_pattern::build (vec_info *vinfo)
{
  /* Take references on every operand that is kept *before* releasing the old
     children.  The operands were collected from underneath those children
     during matching, so releasing the children first could drop the operands
     while they are still needed.  vect_build_combine_node increments the
     reference counts of m_ops[0] and m_ops[1]; it only reads the lane count,
     representative and vectype of the node, none of which depend on the
     children being released below.  */
  slp_tree combined
    = vect_build_combine_node (this->m_ops[0], this->m_ops[1], *this->m_node);
  SLP_TREE_REF_COUNT (this->m_ops[2])++;

  slp_tree node;
  unsigned i;
  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
    vect_free_slp_tree (node);

  /* First re-arrange the children.  */
  SLP_TREE_CHILDREN (*this->m_node).reserve_exact (2);
  SLP_TREE_CHILDREN (*this->m_node)[0] = this->m_ops[2];
  SLP_TREE_CHILDREN (*this->m_node)[1] = combined;

  /* And then rewrite the node itself.  */
  complex_pattern::build (vinfo);
}
/*******************************************************************************
* Pattern matching definitions
******************************************************************************/