Skip to content

Commit

Permalink
handle training state logical (#731)
Browse files Browse the repository at this point in the history
Co-authored-by: stew-ro <60453211+stew-ro@users.noreply.github.com>
  • Loading branch information
starain-pactera and stew-ro authored Nov 9, 2020
1 parent 6203e2c commit 569adf1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
18 changes: 10 additions & 8 deletions src/react/components/pages/editorPage/canvas.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1449,16 +1449,18 @@ export default class Canvas extends React.Component<ICanvasProps, ICanvasState>
if (selectedRegions.length > 0) {
const intersectionResult = _.intersection(selectedRegions, regions);
if (intersectionResult.length === 0) {
const relatedLabels = labels.find(label => selectedRegions.find(sr => sr.tags.find(t => t === label.label)));
if (relatedLabels && relatedLabels.confidence) {
const originLabel = this.props.selectedAsset!.labelData?.labels?.find(a => a.label === relatedLabels.label);
if (originLabel) {
relatedLabels.revised = true;
if(!relatedLabels.originValue){
relatedLabels.originValue = [...originLabel.value];
const relatedLabels = labels.filter(label => selectedRegions.find(sr => sr.tags.find(t => t === label.label)));
relatedLabels?.forEach(relatedLabel=>{
if (relatedLabel && relatedLabel.confidence) {
const originLabel = this.props.selectedAsset!.labelData?.labels?.find(a => a.label === relatedLabel.label);
if (originLabel) {
relatedLabel.revised = true;
if(!relatedLabel.originValue){
relatedLabel.originValue = [...originLabel.value];
}
}
}
}
});
}
}
regions.sort(this.compareRegionOrder);
Expand Down
21 changes: 15 additions & 6 deletions src/react/components/pages/train/trainPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {connect} from "react-redux";
import {RouteComponentProps} from "react-router-dom";
import {bindActionCreators} from "redux";
import url from "url";
import {IAsset} from "vott-react";
import {constants} from "../../../../common/constants";
import {isElectron} from "../../../../common/hostProcess";
import {interpolate, strings} from "../../../../common/strings";
Expand Down Expand Up @@ -316,15 +317,20 @@ export default class TrainPage extends React.Component<ITrainPageProps, ITrainPa
this.trainProcess().then(async (trainResult) => {
const assets = Object.values(this.props.project.assets);
const assetService = new AssetService(this.props.project);

const newAssets = {};
for (const asset of assets) {
const newAsset = _.cloneDeep(asset);
newAsset.labelingState = AssetLabelingState.Trained;

const metadata = await assetService.getAssetMetadata(newAsset);
if (metadata.labelData && metadata.labelData.labelingState !== AssetLabelingState.Trained) {
if (metadata.labelData && metadata.labelData.labels?.findIndex(label=>label.value?.length>0)>=0 && metadata.labelData.labelingState !== AssetLabelingState.Trained) {
metadata.labelData.labelingState = AssetLabelingState.Trained;
await assetService.save({ ...metadata });
metadata.asset.labelingState=AssetLabelingState.Trained;
const newMeta = await assetService.save({ ...metadata });
newAssets[asset.id] = newMeta.asset;
}
}
await this.props.actions.saveProject({...this.props.project, assets: newAssets},false,false);
this.setState((prevState, props) => ({
isTraining: false,
trainMessage: this.getTrainMessage(trainResult),
Expand Down Expand Up @@ -367,6 +373,7 @@ export default class TrainPage extends React.Component<ITrainPageProps, ITrainPa
error?.message !== undefined
? error.message : error,
});
throw error;
}
}

Expand Down Expand Up @@ -431,9 +438,11 @@ export default class TrainPage extends React.Component<ITrainPageProps, ITrainPa
) {
return;
}
assetMetadata.asset.labelingState = AssetLabelingState.ManuallyLabeled;
if (assetMetadata.labelData) {
assetMetadata.labelData.labelingState = AssetLabelingState.ManuallyLabeled;
if(assetMetadata.labelData?.labels?.findIndex(label=>label.value?.length>0)>=0){
assetMetadata.asset.labelingState = AssetLabelingState.ManuallyLabeled;
if (assetMetadata.labelData) {
assetMetadata.labelData.labelingState = AssetLabelingState.ManuallyLabeled;
}
}

assetMetadata.labelData?.labels?.forEach((label) => {
Expand Down
5 changes: 4 additions & 1 deletion src/services/assetService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ export class AssetService {
const ocrHeight = ocrExtent[3] - ocrExtent[1];
const result = [];
for (let i = 0; i < arr.length; i += 2) {
result.push((arr[i] / ocrWidth), (arr[i + 1] / ocrHeight));
result.push(...[
(arr[i] / ocrWidth),
(arr[i + 1] / ocrHeight),
]);
}
return result;
};
Expand Down

0 comments on commit 569adf1

Please sign in to comment.