Skip to content

Commit

Permalink
Fix broken phase field test
Browse files Browse the repository at this point in the history
I don't understand the error very well. I've fixed it by once
again switching between lax.cond and np.where, which must be the
sixth time I've done this in the plasticity models. I'm still not
sure what the rules are for short circuiting in jax and how to be
certain that the false branch of the code won't be executed under
AD and/or jit. Comprehensive unit testing is in order.
  • Loading branch information
btalamini committed Mar 11, 2022
1 parent 38ec4da commit b7d1dd2
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions optimism/phasefield/PhaseFieldLorentzPlastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
STATE_EQPS = 0
STATE_PLASTIC_STRAIN = slice(1,1+9)
STATE_PLASTIC_DISTORTION = STATE_PLASTIC_STRAIN
STATE_FREE_ENERGY = 10
NUM_STATE_VARS = 11
NUM_STATE_VARS = 10


def create_material_model_functions(properties):
Expand Down Expand Up @@ -66,7 +65,7 @@ def create_material_model_functions(properties):

def compute_energy_density(dispGrad, phase, phaseGrad, internalVars):
elasticTrialStrain = compute_elastic_strain(dispGrad, internalVars)
return energy_density_generic(elasticTrialStrain, phase, phaseGrad, internalVars, props, hardeningModel, doUpdate=True) - internalVars[STATE_FREE_ENERGY]
return energy_density_generic(elasticTrialStrain, phase, phaseGrad, internalVars, props, hardeningModel, doUpdate=True)

def compute_output_energy_density(dispGrad, phase, phaseGrad, internalVars):
elasticStrain = compute_elastic_strain(dispGrad, internalVars)
Expand Down Expand Up @@ -109,16 +108,14 @@ def make_properties(E, nu, Gc, psiC, l, Y0):
def make_initial_state_finite_deformations(shape=(1,)):
eqps = 0.0
Fp = np.identity(3)
psi = 0.0
pointState = np.hstack((eqps, Fp.ravel(), psi))
pointState = np.hstack((eqps, Fp.ravel()))
return np.tile(pointState, shape)


def make_initial_state_small_deformations(shape=(1,)):
eqps = 0.0
plasticStrain = np.zeros((3,3))
psi = 0.0
pointState = np.hstack((eqps, plasticStrain.ravel(), psi))
pointState = np.hstack((eqps, plasticStrain.ravel()))
return np.tile(pointState, shape)


Expand All @@ -130,9 +127,7 @@ def compute_state_new_small_deformations(dispGrad, phase, phaseGrad, stateOld, p
eqpsNew = stateOld[STATE_EQPS] + stateInc[STATE_EQPS]
plasticStrainNew = stateOld[STATE_PLASTIC_STRAIN] + stateInc[STATE_PLASTIC_STRAIN]
elasticStrainNew = strain - plasticStrainNew
psi = compute_free_energy_density(elasticStrainNew, phase, phaseGrad,
eqpsNew, props, hardeningModel)
return np.hstack((eqpsNew, plasticStrainNew, psi))
return np.hstack((eqpsNew, plasticStrainNew))


def compute_state_new_finite_deformations(dispGrad, phase, phaseGrad, stateOld, props, hardeningModel):
Expand All @@ -143,9 +138,7 @@ def compute_state_new_finite_deformations(dispGrad, phase, phaseGrad, stateOld,
FpOld = np.reshape(stateOld[STATE_PLASTIC_STRAIN], (3,3))
FpNew = expm(stateInc[STATE_PLASTIC_STRAIN].reshape((3,3)))@FpOld
elasticStrainNew = elasticTrialStrain - stateInc[STATE_PLASTIC_STRAIN].reshape((3,3))
psi = compute_free_energy_density(elasticStrainNew, phase, phaseGrad,
eqpsNew, props, hardeningModel)
return np.hstack((eqpsNew, FpNew.ravel(), psi))
return np.hstack((eqpsNew, FpNew.ravel()))


def energy_density_generic(elStrain, phase, phaseGrad, state, props, hardeningModel, doUpdate):
Expand All @@ -154,11 +147,18 @@ def energy_density_generic(elStrain, phase, phaseGrad, state, props, hardeningMo
N = compute_flow_direction(elStrain)
trialStress = 2 * degradation(phase,props) * props[PROPS_MU] * np.tensordot(TensorMath.dev(elStrain), N)
flowStress = hardeningModel.compute_flow_stress(eqps)
isYielding = trialStress > flowStress

stateInc = np.where(isYielding,
update_state(elStrain, state, state, phase, props, hardeningModel),
np.zeros(NUM_STATE_VARS))

# Other way to introduce conditional
# stateInc = lax.cond(isYielding,
# lambda e: update_state(e, state, state, phase, props, hardeningModel),
# lambda e: np.zeros(NUM_STATE_VARS),
# elStrain)

stateInc = lax.cond(trialStress > flowStress,
lambda e: update_state(e, state, state, phase, props, hardeningModel),
lambda e: np.zeros(NUM_STATE_VARS),
elStrain)

else:
stateInc = np.zeros(NUM_STATE_VARS)
Expand All @@ -176,11 +176,16 @@ def compute_state_increment(elasticTrialStrain, phase, stateOld, props, hardenin
N = compute_flow_direction(elasticTrialStrain)
trialStress = 2 * degradation(phase,props)*props[PROPS_MU] * np.tensordot(TensorMath.dev(elasticTrialStrain), N)
flowStress = hardeningModel.compute_flow_stress(eqpsOld)
isYielding = trialStress > flowStress

stateInc = lax.cond(trialStress > flowStress,
lambda e: update_state(e, stateOld, stateOld, phase, props, hardeningModel),
lambda e: np.zeros(NUM_STATE_VARS),
elasticTrialStrain)
stateInc = np.where(isYielding,
update_state(elasticTrialStrain, stateOld, stateOld, phase, props, hardeningModel),
np.zeros(NUM_STATE_VARS))

# stateInc = lax.cond(trialStress > flowStress,
# lambda e: update_state(e, stateOld, stateOld, phase, props, hardeningModel),
# lambda e: np.zeros(NUM_STATE_VARS),
# elasticTrialStrain)

return stateInc

Expand Down Expand Up @@ -263,8 +268,7 @@ def radial_return_jvp(primals, vt):
DeltaEqps = eqps - eqpsOld
N = compute_flow_direction(elasticTrialStrain)
DeltaPlasticStrain = DeltaEqps*N
dummyFreeEnergyChange = 0.0
return np.hstack((DeltaEqps, DeltaPlasticStrain.ravel(), dummyFreeEnergyChange))
return np.hstack((DeltaEqps, DeltaPlasticStrain.ravel()))


def compute_elastic_linear_strain(dispGrad, plasticStrain):
Expand Down

0 comments on commit b7d1dd2

Please sign in to comment.