Upload folder using huggingface_hub
scripts/model.py +6 -39
scripts/model.py CHANGED
@@ -50,13 +50,9 @@ def improve_ocr_accuracy(img):
         np.ndarray: The preprocessed image as a binary thresholded array.

     """
-    # Read image with PIL (for color preservation)
     img = Image.open(img)
-
-    # Increase image size (can improve accuracy for small text)
     img = img.resize((img.width * 4, img.height * 4))

-    # Increase contrast
     enhancer = ImageEnhance.Contrast(img)
     img = enhancer.enhance(2)

@@ -78,27 +74,19 @@ def create_ocr_outputs():
     directory_path = os.getcwd() + '/data/processed/hand_labeled_tables/hand_labeled_tables'

     for root, dirs, files in os.walk(directory_path):
-        # Print the current directory
-        print(f"Current directory: {root}")
-
-        # Print all subdirectories in the current directory
-        print("Subdirectories:")
         for dir in dirs:
             print(f"- {dir}")

-        # Print all files in the current directory
-        print("Files:")
         for image_path in files:
             print(f"- {image_path}")
             full_path = os.path.join(root, image_path)
-            # Preprocess the image
             preprocessed_image = improve_ocr_accuracy(full_path)

             ocr_text = ocr_core(preprocessed_image)
             with open(os.getcwd() + f"/data/processed/annotations/{image_path.split('.')[0]}.txt", 'wb') as f:
                 f.write(ocr_text.encode('utf-8'))

-        print("\n")
+        print("\n")

 def prepare_dataset(ocr_dir, csv_dir, output_file):
     """
@@ -143,10 +131,7 @@ def tokenize_function(examples):
         dict: A dictionary containing tokenized inputs and labels.

     """
-    # Tokenize the inputs
     inputs = tokenizer(examples['prompt'], truncation=True, padding='max_length', max_length=1012)
-
-    # Create labels which are the same as input_ids
     inputs['labels'] = inputs['input_ids'].copy()
     return inputs

@@ -172,29 +157,23 @@ def calculate_metrics(model, tokenizer, texts, labels):

     with torch.no_grad():
         for text, label in zip(texts, labels):
-            # Tokenize input and label
             input_ids = tokenizer.encode(text, return_tensors="pt")
             label_ids = tokenizer.encode(label, return_tensors="pt")[0]

-            # Generate prediction
             output = model.generate(input_ids, max_length=input_ids.shape[1] + len(label_ids), num_return_sequences=1)
             predicted_ids = output[0][input_ids.shape[1]:]

-            # Convert ids to tokens
             predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids)
             label_tokens = tokenizer.convert_ids_to_tokens(label_ids)

-            # Extend predictions and labels
             all_predictions.extend(predicted_tokens)
             all_labels.extend(label_tokens)

-            # Calculate loss
             outputs = model(input_ids=input_ids, labels=label_ids.unsqueeze(0))
             loss = outputs.loss
             total_loss += loss.item() * len(label_ids)
             total_tokens += len(label_ids)

-    # Calculate metrics
     precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
     recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)
     f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)
@@ -211,10 +190,8 @@ if __name__ == '__main__':
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")

-    # Load a pretrained YOLOv8 model
     model = YOLO('yolov8l.pt')

-    # Train the model on your custom dataset
     results = model.train(
         data='config.yaml',
         epochs=1,
@@ -232,15 +209,11 @@ if __name__ == '__main__':
     model.save(os.getcwd() + '/models/trained_yolov8.pt')

     create_ocr_outputs()
-
-    # Usage
     ocr_dir = os.getcwd() + '/data/processed/annotations'
     csv_dir = os.getcwd() + '/data/processed/hand_labeled_tables'
     output_file = 'dataset.jsonl'
     prepare_dataset(ocr_dir, csv_dir, output_file)

-
-    # Load the dataset
     dataset = load_dataset('json', data_files={'train': 'dataset.jsonl'})
     dataset = dataset['train'].train_test_split(test_size=0.1)

@@ -262,13 +235,12 @@ if __name__ == '__main__':
         weight_decay=0.01,
         logging_dir='./logs',
         logging_steps=10,
-        evaluation_strategy="epoch",
-        save_strategy="epoch",
-        load_best_model_at_end=True,
-        metric_for_best_model="eval_loss",
+        evaluation_strategy="epoch",
+        save_strategy="epoch",
+        load_best_model_at_end=True,
+        metric_for_best_model="eval_loss",
     )

-    # Trainer
     trainer = Trainer(
         model=model,
         args=training_args,
@@ -276,21 +248,16 @@ if __name__ == '__main__':
         eval_dataset=tokenized_dataset['test'],
     )

-    # Train the model
     trainer.train()

-    # Evaluate the model
     eval_results = trainer.evaluate()
     print(f"Evaluation results: {eval_results}")

-    # Save the model
     gpt_model.save_pretrained(os.getcwd() + '/models/gpt')
     tokenizer.save_pretrained(os.getcwd() + '/models/gpt')
-
-    # Calculate metrics
+
     precision, recall, f1 = calculate_metrics(gpt_model, tokenizer, dataset['test']['text'], dataset['test']['label'])

-    # Display metrics
     print(f"Precision: {precision:.4f}")
     print(f"Recall: {recall:.4f}")
     print(f"F1 Score: {f1:.4f}")