
Commit fbaa63a

Author: backtime92
Commit message: train for SynthText data
Parent: d24fd68

44 files changed: 1971 additions & 3546 deletions


.gitignore

Lines changed: 11 additions & 1 deletion
```diff
@@ -2,4 +2,14 @@
 data/test/icdar2013/*.jpg
 data/test/icdar2015/*.jpg
 pretrain/*.pth
-__pycache__
+__init__.py
+craft.py
+craft_ic15_20k.pth
+craft_utils.py
+enlarge.py
+file_utils.py
+requirments.txt
+watershed.py
+basenet/*
+weights/*
+__pycache__
```

README.md

Lines changed: 3 additions & 81 deletions
```diff
@@ -1,7 +1,6 @@
 # CRAFT-Reimplementation
 # Note: If you have any problems, please comment, or join our WeChat group. The QR code is updated in issue #49.
-# Update: Refactoring has started. Since the refactoring is all done outside of work hours, the refactoring and training cycle is relatively long. The new Gaussian-map generation method has now been verified experimentally: comparing against the figures in the author's paper, together with earlier experiments and questions put to the author, it matches the author's approach. The synthetic-data training part is expected to be released around the 20th. Because the company's connection to GitHub is slow, the Gaussian-map generation part cannot be uploaded for now.
-# Update 2: Fully supervised training on synthetic data is under way, and the weakly supervised part is being coded and tested in parallel. If training on the synthetic data meets expectations, that training code is expected to be released on Friday.
+# Apologies for not maintaining this project for so long. Since many people have shown interest recently, I plan to resume maintenance by late November or early December at the latest. The code is rather messy because my engineering skills were limited during my internship, and this round of maintenance will clean it up. It should take about two weeks: I will reorganize the code, retrain and upload the pretrained models, and add comments on the key experiments and the ideas behind them. Please stay tuned.
 
 ## Reimplementation: Character Region Awareness for Text Detection, reimplemented in Pytorch
 
@@ -19,83 +18,6 @@ The full paper is available at: https://arxiv.org/pdf/1904.01941.pdf
 5、4 nvidia GPUs (we use 4 nvidia titanX)
 
 
-## pre-trained model:
-`NOTE: These are old pre-trained models; I will upload links to the new pre-trained models.`
-Syndata: [Syndata for baidu drive](https://pan.baidu.com/s/1MaznjE79JNS9Ld48ZtRefg) || [Syndata for google drive](https://drive.google.com/file/d/1FvqfBMZQJeZXGfZLl-840YXoeYK8CNwk/view?usp=sharing)
-Syndata+IC15: [Syndata+IC15 for baidu drive](https://pan.baidu.com/s/19lJRM6YWZXVkZ_aytsYSiQ) || [Syndata+IC15 for google
-drive](https://drive.google.com/file/d/1k17GuBG_omT91tJoIMSlLrorYbLXkq4z/view?usp=sharing)
-Syndata+IC13+IC17: [Syndata+IC13+IC17 for baidu drive](https://pan.baidu.com/s/1PTTzbM9XG0pNe5i-uL6Aag) || [Syndata+IC13+IC17 for google drive](https://drive.google.com/open?id=1SkJEfaGYIq-eFxfzFVZb-cGdGWR8lPSi)
-
-
-## Training
-`Note: When you train on the IC15 or MLT data, please see the annotations in data_loader.py at line 92 and lines 108-112.`
-
-### Train for Syndata
-- download the Syndata (I will give the link)
-- change the path in the basenet/vgg16_bn.py file:
-> `(/data/CRAFT-pytorch/vgg16_bn-6c64b313.pth -> /your_path/vgg16_bn-6c64b313.pth). You can download the model here:` [baidu](https://pan.baidu.com/s/1_h5qdwYQAToDi_BB5Eg3vg) || [google](https://drive.google.com/open?id=1ZtvGpFQrbmEisB_GhmZb8UQOtvqY_-tW)
-- change the paths in the trainSyndata.py file:
-> `(1. /data/CRAFT-pytorch/SynthText -> /your_path/SynthText 2. /data/CRAFT-pytorch/synweights/synweights -> /your_path/real_weights)`
-- Run **`python trainSyndata.py`**
-
-### Train for IC15 data based on the Syndata pre-trained model
-- download the IC15 data, and rename the image files and gt files from ch4_training_images and ch4_training_localization_transcription_gt, respectively.
-- change the path in the basenet/vgg16_bn.py file:
-> `(/data/CRAFT-pytorch/vgg16_bn-6c64b313.pth -> /your_path/vgg16_bn-6c64b313.pth). You can download the model here:` [baidu](https://pan.baidu.com/s/1_h5qdwYQAToDi_BB5Eg3vg) || [google](https://drive.google.com/open?id=1ZtvGpFQrbmEisB_GhmZb8UQOtvqY_-tW)
-- change the paths in the trainic15data.py file:
-> `(1. /data/CRAFT-pytorch/SynthText -> /your_path/SynthText 2. /data/CRAFT-pytorch/real_weights -> /your_path/real_weights)`
-- change the paths in the trainic15data.py file:
-> `(1. /data/CRAFT-pytorch/1-7.pth -> /your_path/your_pre-trained_model_name 2. /data/CRAFT-pytorch/icdar1317 -> /your_ic15data_path/)`
-- Run **`python trainic15data.py`**
-
-### Train for IC13+17 data based on the Syndata pre-trained model
-
-- download the MLT data, and rename the image files and gt files, respectively.
-- change the path in the basenet/vgg16_bn.py file:
-> `(/data/CRAFT-pytorch/vgg16_bn-6c64b313.pth -> /your_path/vgg16_bn-6c64b313.pth). You can download the model here:` [baidu](https://pan.baidu.com/s/1_h5qdwYQAToDi_BB5Eg3vg) || [google](https://drive.google.com/open?id=1ZtvGpFQrbmEisB_GhmZb8UQOtvqY_-tW)
-- change the paths in the trainic-MLT_data.py file:
-> `(1. /data/CRAFT-pytorch/SynthText -> /your_path/SynthText 2. savemodel path -> your savemodel path)`
-- change the paths in the trainic-MLT_data.py file:
-> `(1. /data/CRAFT-pytorch/1-7.pth -> /your_path/your_pre-trained_model_name 2. /data/CRAFT-pytorch/icdar1317 -> /your_ic15data_path/)`
-- Run **`python trainic-MLT_data.py`**
-
-### If you want to train weakly supervised, use our Syndata pre-trained model:
-1. First download the pre-trained model trained on the Syndata: [baidu](https://pan.baidu.com/s/1MaznjE79JNS9Ld48ZtRefg) || [google](https://drive.google.com/file/d/1FvqfBMZQJeZXGfZLl-840YXoeYK8CNwk/view?usp=sharing).
-2. change the data path and the pre-trained model path.
-3. run `python trainic15data.py`
-
-
-**This code supports Syndata and icdar2015; we will release the training code for IC13 and IC17 as soon as possible.**
-
-Methods | dataset | Recall | precision | H-mean
-----------------------------------------------|---------|--------|-----------|-------
-Syndata | ICDAR13 | 71.93% | 81.31% | 76.33%
-Syndata+IC15 | ICDAR15 | 76.12% | 84.55% | 80.11%
-Syndata+MLT(deteval) | ICDAR13 | 86.81% | 95.28% | 90.85%
-Syndata+MLT(deteval)(new gaussian map method) | ICDAR13 | 90.67% | 94.56% | 92.57%
-Syndata+IC15(new gaussian map method) | ICDAR15 | 80.36% | 84.25% | 82.26%
-
-### We have released the latest code with the new gaussian map and random crop algorithm.
-**`Note: the new gaussian map method can split the inference gaussian region score map`**
-`Sample:`
-<img src="https://github.com/backtime92/CRAFT-Reimplementation/blob/master/image/test3_score.jpg" width="384" height="512" /><img src="https://github.com/backtime92/CRAFT-Reimplementation/blob/master/image/test3_affinity.jpg" width="384" height="256" />
-
-**`Note: We have solved the problem of detecting big words. Now we are training the model, and any issues or advice are welcome.`**
-
-`Sample:`
-<img src="https://github.com/backtime92/CRAFT-Reimplementation/blob/master/image/test4_score.jpg" width="384" height="512" /><img src="https://github.com/backtime92/CRAFT-Reimplementation/blob/master/image/test4_affinity.jpg" width="384" height="256" />
-
-### weChat QR code
-<img src="https://github.com/backtime92/CRAFT-Reimplementation/blob/master/image/wechatgroup.jpeg" width="150" height="150" />
-
-
-# Contributing to the project
-`We will release the training code as soon as possible; we have not yet reached the results given in the author's paper. Any pull requests or issues are welcome, and we also hope you can give us some advice on the project.`
-
-# Acknowledgement
-Thanks to Youngmin Baek, Bado Lee, Dongyoon Han, Sangdoo Yun, and Hwalsuk Lee for their excellent work and [code](https://github.com/clovaai/CRAFT-pytorch) for testing. In this repo, we use the author repo's basenet and test code.
-
-# License
-For commercial use, please contact us.
-
+# to do list
+Release the strong supervision training part in early December
 
```
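The README's training steps revolve around pointing the training scripts at pre-trained checkpoints. One practical wrinkle the README does not spell out: checkpoints saved from an `nn.DataParallel` model prefix every parameter key with `module.`, which must be stripped before calling `load_state_dict` on a single-GPU model. Below is a minimal sketch of such a key-renaming helper; the name `copyStateDict` matches the helper referenced in the repo's old basenet code, but this body is an illustrative reconstruction, not necessarily the repo's exact implementation:

```python
from collections import OrderedDict

def copyStateDict(state_dict):
    # Strip a leading "module." (added by nn.DataParallel) from every key.
    start_idx = 1 if list(state_dict.keys())[0].startswith("module") else 0
    new_state_dict = OrderedDict()
    for name, value in state_dict.items():
        new_state_dict[".".join(name.split(".")[start_idx:])] = value
    return new_state_dict

# Toy checkpoint standing in for torch.load("/your_path/your_model.pth"):
ckpt = OrderedDict([("module.conv1.weight", 1), ("module.conv1.bias", 2)])
print(list(copyStateDict(ckpt).keys()))  # → ['conv1.weight', 'conv1.bias']
```

Keys without the prefix pass through unchanged, so the helper is safe to apply to both single-GPU and DataParallel checkpoints.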

basenet/vgg16_bn.py

Lines changed: 26 additions & 21 deletions
```diff
@@ -1,43 +1,48 @@
 from collections import namedtuple
 
 import torch
+import torch.nn as nn
+import torch.nn.init as init
 from torchvision import models
 from torchvision.models.vgg import model_urls
-from torchutil import *
-import os
-
-weights_folder = os.path.join(os.path.dirname(__file__) + '/../pretrain')
 
+def init_weights(modules):
+    for m in modules:
+        if isinstance(m, nn.Conv2d):
+            init.xavier_uniform_(m.weight.data)
+            if m.bias is not None:
+                m.bias.data.zero_()
+        elif isinstance(m, nn.BatchNorm2d):
+            m.weight.data.fill_(1)
+            m.bias.data.zero_()
+        elif isinstance(m, nn.Linear):
+            m.weight.data.normal_(0, 0.01)
+            m.bias.data.zero_()
 
 class vgg16_bn(torch.nn.Module):
-    def __init__(self, pretrained=True, freeze=False):
+    def __init__(self, pretrained=True, freeze=True):
         super(vgg16_bn, self).__init__()
         model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
-        # vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
-        vgg_pretrained_features = models.vgg16_bn(pretrained=False)
-        if pretrained:
-            vgg_pretrained_features.load_state_dict(
-                copyStateDict(torch.load(os.path.join(weights_folder, '/data/CRAFT-pytorch/vgg16_bn-6c64b313.pth'))))
-        vgg_pretrained_features = vgg_pretrained_features.features
+        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
         self.slice1 = torch.nn.Sequential()
         self.slice2 = torch.nn.Sequential()
         self.slice3 = torch.nn.Sequential()
         self.slice4 = torch.nn.Sequential()
         self.slice5 = torch.nn.Sequential()
-        for x in range(12):             # conv2_2
+        for x in range(12):  # conv2_2
             self.slice1.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(12, 19):         # conv3_3
+        for x in range(12, 19):  # conv3_3
             self.slice2.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(19, 29):         # conv4_3
+        for x in range(19, 29):  # conv4_3
             self.slice3.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(29, 39):         # conv5_3
+        for x in range(29, 39):  # conv5_3
             self.slice4.add_module(str(x), vgg_pretrained_features[x])
 
         # fc6, fc7 without atrous conv
         self.slice5 = torch.nn.Sequential(
-                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
-                nn.Conv2d(1024, 1024, kernel_size=1)
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
+            nn.Conv2d(1024, 1024, kernel_size=1)
         )
 
         if not pretrained:
@@ -46,11 +51,11 @@ def __init__(self, pretrained=True, freeze=False):
             init_weights(self.slice3.modules())
             init_weights(self.slice4.modules())
 
-        init_weights(self.slice5.modules())     # no pretrained model for fc6 and fc7
+        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7
 
         if freeze:
-            for param in self.slice1.parameters():      # only first conv
-                param.requires_grad = False
+            for param in self.slice1.parameters():  # only first conv
+                param.requires_grad = False
 
     def forward(self, X):
         h = self.slice1(X)
```
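The diff above adds an `init_weights` helper that the class applies to any slices that are not pre-trained. Its observable effect is easy to check: Conv2d layers get Xavier-uniform weights and zeroed biases, BatchNorm2d layers are reset to weight 1 and bias 0, and Linear layers get small normal weights with zero bias. A self-contained sketch, where the helper body is copied from the diff and the toy `nn.Sequential` model is purely illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.init as init

def init_weights(modules):
    # Same logic as the helper added to basenet/vgg16_bn.py in this commit.
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

# Toy model standing in for the non-pretrained slices (illustrative only).
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8),
                    nn.Flatten(), nn.Linear(8, 2))
for p in net.parameters():
    p.data.fill_(0.5)          # perturb so the reset below is observable
init_weights(net.modules())

assert torch.all(net[0].bias == 0)    # Conv2d bias zeroed
assert torch.all(net[1].weight == 1)  # BatchNorm2d weight filled with 1
assert torch.all(net[3].bias == 0)    # Linear bias zeroed
```

Note that `net.modules()` also yields the containing `Sequential` and the `Flatten` layer; the `isinstance` chain simply skips anything that is not a conv, batch-norm, or linear layer.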

coordinates.py

Lines changed: 0 additions & 138 deletions
This file was deleted.

0 commit comments