Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8342393: Promote commutative vector IR node sharing #22863

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/hotspot/share/opto/node.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, Alibaba Group Holding Limited. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
Expand Down Expand Up @@ -832,7 +832,8 @@ class Node {
Flag_for_post_loop_opts_igvn = 1 << 15,
Flag_is_removed_by_peephole = 1 << 16,
Flag_is_predicated_using_blend = 1 << 17,
_last_flag = Flag_is_predicated_using_blend
Flag_is_commutative_vector_oper = 1 << 18,
_last_flag = Flag_is_commutative_vector_oper
};

class PD;
Expand Down Expand Up @@ -1069,6 +1070,8 @@ class Node {

bool is_predicated_using_blend() const { return (_flags & Flag_is_predicated_using_blend) != 0; }

bool is_commutative_vector_operation() const { return (_flags & Flag_is_commutative_vector_oper) != 0; }

// Used in lcm to mark nodes that have scheduled
bool is_scheduled() const { return (_flags & Flag_is_scheduled) != 0; }

Expand Down
41 changes: 30 additions & 11 deletions src/hotspot/share/opto/phaseX.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -62,6 +62,31 @@ NodeHash::NodeHash(Arena *arena, uint est_max_size) :
memset(_table,0,sizeof(Node*)*_max);
}

//-----------------------------------------------------------------------------

bool NodeHash::check_for_collision(const Node* n, const Node* k) {
// For commutative operations with same controlling edge
// perform order agnostic input edge comparison to promote
// node sharing.
uint req = n->req();
if (n->is_commutative_vector_operation()) {
assert(req == 3, "");
assert(k->is_commutative_vector_operation(), "");
if ((k->in(0) != n->in(0)) ||
((k->in(1) != n->in(1) || k->in(2) != n->in(2)) &&
(k->in(1) != n->in(2) || k->in(2) != n->in(1)))) {
return true;
}
} else {
for(uint i=0; i<req; i++) {
if(n->in(i) != k->in(i)) { // Different inputs?
return true;
}
}
}
return false;
}

//------------------------------hash_find--------------------------------------
// Find in hash table
Node *NodeHash::hash_find( const Node *n ) {
Expand All @@ -85,15 +110,12 @@ Node *NodeHash::hash_find( const Node *n ) {
while( 1 ) { // While probing hash table
if( k->req() == req && // Same count of inputs
k->Opcode() == op ) { // Same Opcode
for( uint i=0; i<req; i++ )
if( n->in(i)!=k->in(i)) // Different inputs?
goto collision; // "goto" is a speed hack...
if( n->cmp(*k) ) { // Check for any special bits
bool collision = check_for_collision(n, k);
if (collision == false && n->cmp(*k)) { // Check for any special bits
NOT_PRODUCT( _lookup_hits++ );
return k; // Hit!
}
}
collision:
NOT_PRODUCT( _look_probes++ );
key = (key + stride/*7*/) & (_max-1); // Stride through table with relative prime
k = _table[key]; // Get hashed value
Expand Down Expand Up @@ -137,15 +159,12 @@ Node *NodeHash::hash_find_insert( Node *n ) {
while( 1 ) { // While probing hash table
if( k->req() == req && // Same count of inputs
k->Opcode() == op ) { // Same Opcode
for( uint i=0; i<req; i++ )
if( n->in(i)!=k->in(i)) // Different inputs?
goto collision; // "goto" is a speed hack...
if( n->cmp(*k) ) { // Check for any special bits
bool collision = check_for_collision(n, k);
if (collision == false && n->cmp(*k)) { // Check for any special bits
NOT_PRODUCT( _lookup_hits++ );
return k; // Hit!
}
}
collision:
NOT_PRODUCT( _look_probes++ );
key = (key + stride) & (_max-1); // Stride through table w/ relative prime
k = _table[key]; // Get hashed value
Expand Down
1 change: 1 addition & 0 deletions src/hotspot/share/opto/phaseX.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class NodeHash : public AnyObj {
#ifdef ASSERT
~NodeHash(); // Unlock all nodes upon destruction of table.
#endif
bool check_for_collision(const Node* n, const Node* k);
Node *hash_find(const Node*);// Find an equivalent version in hash table
Node *hash_find_insert(Node*);// If not in table insert else return found node
void hash_insert(Node*); // Insert into hash table
Expand Down
94 changes: 73 additions & 21 deletions src/hotspot/share/opto/vectornode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ class VectorNode : public TypeNode {
return type()->ideal_reg();
}

virtual uint hash() const {
if (is_commutative_vector_operation()) {
assert(req() == 3, "");
return (uintptr_t)in(1) + (uintptr_t)in(2) + Opcode();
jatin-bhateja marked this conversation as resolved.
Show resolved Hide resolved
} else {
return Node::hash();
}
}

virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);

static VectorNode* scalar2vector(Node* s, uint vlen, BasicType bt, bool is_mask = false);
Expand Down Expand Up @@ -168,7 +177,7 @@ class SaturatingVectorNode : public VectorNode {
st->print("%s", _is_unsigned ? "{unsigned_vector_node}" : "{signed_vector_node}");
}
#endif
virtual uint hash() const { return Node::hash() + _is_unsigned; }
virtual uint hash() const { return VectorNode::hash() + _is_unsigned; }

bool is_unsigned() { return _is_unsigned; }
};
Expand All @@ -177,47 +186,59 @@ class SaturatingVectorNode : public VectorNode {
// Vector add byte
class AddVBNode : public VectorNode {
public:
AddVBNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
AddVBNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------AddVSNode--------------------------------------
// Vector add char/short
class AddVSNode : public VectorNode {
public:
AddVSNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
AddVSNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------AddVINode--------------------------------------
// Vector add int
class AddVINode : public VectorNode {
public:
AddVINode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
AddVINode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------AddVLNode--------------------------------------
// Vector add long
class AddVLNode : public VectorNode {
public:
AddVLNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
AddVLNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------AddVFNode--------------------------------------
// Vector add float
class AddVFNode : public VectorNode {
public:
AddVFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
AddVFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------AddVDNode--------------------------------------
// Vector add double
class AddVDNode : public VectorNode {
public:
AddVDNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
AddVDNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

Expand Down Expand Up @@ -387,7 +408,9 @@ class SubVLNode : public VectorNode {
// Vector saturating addition.
class SaturatingAddVNode : public SaturatingVectorNode {
public:
SaturatingAddVNode(Node* in1, Node* in2, const TypeVect* vt, bool is_unsigned) : SaturatingVectorNode(in1, in2, vt, is_unsigned) {}
SaturatingAddVNode(Node* in1, Node* in2, const TypeVect* vt, bool is_unsigned) : SaturatingVectorNode(in1, in2, vt, is_unsigned) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

Expand Down Expand Up @@ -419,23 +442,29 @@ class SubVDNode : public VectorNode {
// Vector multiply byte
class MulVBNode : public VectorNode {
public:
MulVBNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
MulVBNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------MulVSNode--------------------------------------
// Vector multiply short
class MulVSNode : public VectorNode {
public:
MulVSNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
MulVSNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------MulVINode--------------------------------------
// Vector multiply int
class MulVINode : public VectorNode {
public:
MulVINode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
MulVINode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

Expand All @@ -445,6 +474,7 @@ class MulVLNode : public VectorNode {
public:
MulVLNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
init_class_id(Class_MulVL);
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
bool has_int_inputs() const;
Expand All @@ -455,15 +485,19 @@ class MulVLNode : public VectorNode {
// Vector multiply float
class MulVFNode : public VectorNode {
public:
MulVFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
MulVFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//------------------------------MulVDNode--------------------------------------
// Vector multiply double
class MulVDNode : public VectorNode {
public:
MulVDNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
MulVDNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

Expand Down Expand Up @@ -605,14 +639,17 @@ class AbsVSNode : public VectorNode {
// Vector Min
class MinVNode : public VectorNode {
public:
MinVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
MinVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

class UMinVNode : public VectorNode {
public:
UMinVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2 ,vt) {
assert(is_integral_type(vt->element_basic_type()), "");
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};
Expand All @@ -621,14 +658,17 @@ class UMinVNode : public VectorNode {
// Vector Max
class MaxVNode : public VectorNode {
public:
MaxVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {}
MaxVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

class UMaxVNode : public VectorNode {
public:
UMaxVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {
assert(is_integral_type(vt->element_basic_type()), "");
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};
Expand Down Expand Up @@ -901,7 +941,9 @@ class RShiftCntVNode : public VectorNode {
// Vector and integer
class AndVNode : public VectorNode {
public:
AndVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
AndVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
virtual Node* Identity(PhaseGVN* phase);
};
Expand All @@ -918,7 +960,9 @@ class AndReductionVNode : public ReductionNode {
// Vector or byte, short, int, long as a reduction
class OrVNode : public VectorNode {
public:
OrVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
OrVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
virtual Node* Identity(PhaseGVN* phase);
};
Expand All @@ -935,7 +979,9 @@ class OrReductionVNode : public ReductionNode {
// Vector xor integer
class XorVNode : public VectorNode {
public:
XorVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
XorVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
};
Expand Down Expand Up @@ -1343,21 +1389,27 @@ class MaskAllNode : public VectorNode {
//--------------------------- Vector mask logical and --------------------------------
class AndVMaskNode : public AndVNode {
public:
AndVMaskNode(Node* in1, Node* in2, const TypeVect* vt) : AndVNode(in1, in2, vt) {}
AndVMaskNode(Node* in1, Node* in2, const TypeVect* vt) : AndVNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//--------------------------- Vector mask logical or ---------------------------------
class OrVMaskNode : public OrVNode {
public:
OrVMaskNode(Node* in1, Node* in2, const TypeVect* vt) : OrVNode(in1, in2, vt) {}
OrVMaskNode(Node* in1, Node* in2, const TypeVect* vt) : OrVNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

//--------------------------- Vector mask logical xor --------------------------------
class XorVMaskNode : public XorVNode {
public:
XorVMaskNode(Node* in1, Node* in2, const TypeVect* vt) : XorVNode(in1, in2, vt) {}
XorVMaskNode(Node* in1, Node* in2, const TypeVect* vt) : XorVNode(in1, in2, vt) {
add_flag(Node::Flag_is_commutative_vector_oper);
}
virtual int Opcode() const;
};

Expand Down
Loading