Ulyha committed (verified)
Commit e480398 · Parent(s): 889d17b

Upload 2 files

Files changed (2):
  1. app.py +626 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,626 @@
+ # Gradio demo: object detection, segmentation, depth estimation, captioning,
+ # VQA, CLIP zero-shot classification / retrieval and SAM point-prompt masks.
+ import tempfile
+ from typing import List, Tuple, Any
+
+ import gradio as gr
+ import torch
+ import torch.nn.functional as torch_functional
+ from PIL import Image, ImageDraw
+ from transformers import (
+     CLIPModel,
+     CLIPProcessor,
+     SamModel,
+     SamProcessor,
+     BlipForQuestionAnswering,
+     BlipProcessor,
+     pipeline,
+ )
+
+ # In-memory cache so each model / processor is loaded at most once per process.
+ MODEL_STORE = {}
+
+
+ def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
+     # Gradio may pass PIL images, file paths, (image, caption) pairs or dicts;
+     # flatten whatever arrives into a plain list of RGB PIL images.
+     if not gallery_value:
+         return []
+
+     normalized_images: List[Image.Image] = []
+
+     for item in gallery_value:
+         if isinstance(item, Image.Image):
+             normalized_images.append(item)
+             continue
+
+         if isinstance(item, str):
+             try:
+                 image_object = Image.open(item).convert("RGB")
+                 normalized_images.append(image_object)
+             except Exception:
+                 continue
+             continue
+
+         if isinstance(item, (list, tuple)) and item:
+             candidate = item[0]
+             if isinstance(candidate, Image.Image):
+                 normalized_images.append(candidate)
+             continue
+
+         if isinstance(item, dict):
+             candidate = item.get("image") or item.get("value")
+             if isinstance(candidate, Image.Image):
+                 normalized_images.append(candidate)
+             continue
+
+     return normalized_images
+
+
+ def get_vision_pipeline(model_key: str):
+     # Lazily build and cache a transformers pipeline for the requested task.
+     if model_key in MODEL_STORE:
+         return MODEL_STORE[model_key]
+
+     if model_key == "object_detection_conditional_detr":
+         vision_pipeline = pipeline(
+             task="object-detection",
+             model="microsoft/conditional-detr-resnet-50",
+         )
+     elif model_key == "object_detection_yolos_small":
+         vision_pipeline = pipeline(
+             task="object-detection",
+             model="hustvl/yolos-small",
+         )
+     elif model_key == "segmentation":
+         vision_pipeline = pipeline(
+             task="image-segmentation",
+             model="nvidia/segformer-b0-finetuned-ade-512-512",
+         )
+     elif model_key == "depth_estimation":
+         vision_pipeline = pipeline(
+             task="depth-estimation",
+             model="Intel/dpt-hybrid-midas",
+         )
+     elif model_key == "captioning_blip_base":
+         vision_pipeline = pipeline(
+             task="image-to-text",
+             model="Salesforce/blip-image-captioning-base",
+         )
+     elif model_key == "captioning_blip_large":
+         vision_pipeline = pipeline(
+             task="image-to-text",
+             model="Salesforce/blip-image-captioning-large",
+         )
+     elif model_key == "vqa_blip_base":
+         vision_pipeline = pipeline(
+             task="visual-question-answering",
+             model="Salesforce/blip-vqa-base",
+         )
+     elif model_key == "vqa_vilt_b32":
+         vision_pipeline = pipeline(
+             task="visual-question-answering",
+             model="dandelin/vilt-b32-finetuned-vqa",
+         )
+     else:
+         raise ValueError(f"Unknown model type: {model_key}")
+
+     MODEL_STORE[model_key] = vision_pipeline
+     return vision_pipeline
+
+
+ def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
+     model_store_key_model = f"clip_model_{clip_key}"
+     model_store_key_processor = f"clip_processor_{clip_key}"
+
+     if model_store_key_model not in MODEL_STORE or model_store_key_processor not in MODEL_STORE:
+         if clip_key == "clip_large_patch14":
+             clip_name = "openai/clip-vit-large-patch14"
+         elif clip_key == "clip_base_patch32":
+             clip_name = "openai/clip-vit-base-patch32"
+         else:
+             raise ValueError(f"Unknown CLIP model variant: {clip_key}")
+
+         clip_model = CLIPModel.from_pretrained(clip_name)
+         clip_processor = CLIPProcessor.from_pretrained(clip_name)
+
+         MODEL_STORE[model_store_key_model] = clip_model
+         MODEL_STORE[model_store_key_processor] = clip_processor
+
+     clip_model = MODEL_STORE[model_store_key_model]
+     clip_processor = MODEL_STORE[model_store_key_processor]
+     return clip_model, clip_processor
+
+
+ def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
+     if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
+         blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+         blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
+         MODEL_STORE["blip_vqa_model"] = blip_model
+         MODEL_STORE["blip_vqa_processor"] = blip_processor
+
+     blip_model = MODEL_STORE["blip_vqa_model"]
+     blip_processor = MODEL_STORE["blip_vqa_processor"]
+     return blip_model, blip_processor
+
+
+ def get_sam_components() -> Tuple[SamModel, SamProcessor]:
+     if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
+         sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
+         sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
+         MODEL_STORE["sam_model"] = sam_model
+         MODEL_STORE["sam_processor"] = sam_processor
+
+     sam_model = MODEL_STORE["sam_model"]
+     sam_processor = MODEL_STORE["sam_processor"]
+     return sam_model, sam_processor
+
+
+ def detect_objects_on_image(image_object, model_key: str):
+     if image_object is None:
+         return None
+
+     try:
+         detector_pipeline = get_vision_pipeline(model_key)
+         detection_results = detector_pipeline(image_object)
+
+         # Draw every detected box with its label and confidence onto the image.
+         drawer_object = ImageDraw.Draw(image_object)
+         for detection_item in detection_results:
+             box_data = detection_item["box"]
+             label_value = detection_item["label"]
+             score_value = detection_item["score"]
+
+             drawer_object.rectangle(
+                 [
+                     box_data["xmin"],
+                     box_data["ymin"],
+                     box_data["xmax"],
+                     box_data["ymax"],
+                 ],
+                 outline="red",
+                 width=3,
+             )
+             drawer_object.text(
+                 (box_data["xmin"], box_data["ymin"]),
+                 f"{label_value}: {score_value:.2f}",
+                 fill="red",
+             )
+
+         return image_object
+     except Exception as e:
+         print(f"Error: {str(e)}")
+         return None
+
+
+ def segment_image(image_object):
+     if image_object is None:
+         return None
+
+     try:
+         segmentation_pipeline = get_vision_pipeline("segmentation")
+         segmentation_results = segmentation_pipeline(image_object)
+         # Return the mask of the first detected segment.
+         return segmentation_results[0]["mask"]
+     except Exception as e:
+         print(f"Error: {str(e)}")
+         return None
+
+
+ def estimate_image_depth(image_object):
+     if image_object is None:
+         return None
+
+     try:
+         depth_pipeline = get_vision_pipeline("depth_estimation")
+         depth_output = depth_pipeline(image_object)
+
+         predicted_depth_tensor = depth_output["predicted_depth"]
+
+         # interpolate() needs a 4D (N, C, H, W) tensor.
+         if predicted_depth_tensor.ndim == 3:
+             predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
+         elif predicted_depth_tensor.ndim == 2:
+             predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
+         else:
+             raise ValueError(
+                 f"Unexpected tensor shape: {predicted_depth_tensor.shape}"
+             )
+
+         # Resize back to the input resolution; PIL size is (width, height),
+         # while interpolate expects (height, width).
+         resized_depth_tensor = torch_functional.interpolate(
+             predicted_depth_tensor,
+             size=image_object.size[::-1],
+             mode="bicubic",
+             align_corners=False,
+         )
+
+         depth_array = resized_depth_tensor.squeeze().cpu().numpy()
+         max_value = float(depth_array.max())
+
+         if max_value <= 0.0:
+             return Image.new("L", image_object.size, color=0)
+
+         # Scale to 0-255 for display as a grayscale depth map.
+         normalized_depth_array = (depth_array * 255.0 / max_value).astype("uint8")
+         depth_image = Image.fromarray(normalized_depth_array, mode="L")
+         return depth_image
+     except Exception as e:
+         print(f"Error: {str(e)}")
+         return None
+
+
+ def generate_image_caption(image_object, model_key: str) -> str:
+     if image_object is None:
+         return "Upload an image"
+
+     try:
+         caption_pipeline = get_vision_pipeline(model_key)
+         caption_result = caption_pipeline(image_object)
+         return caption_result[0]["generated_text"]
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+
+ def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
+     if image_object is None:
+         return "Upload an image"
+
+     if not question_text.strip():
+         return "Enter a question"
+
+     try:
+         if model_key == "vqa_blip_base":
+             # BLIP VQA is generative, so it is run manually (generate + decode);
+             # ViLT goes through the standard pipeline below.
+             blip_model, blip_processor = get_blip_vqa_components()
+
+             inputs = blip_processor(
+                 images=image_object,
+                 text=question_text,
+                 return_tensors="pt",
+             )
+
+             with torch.no_grad():
+                 output_ids = blip_model.generate(**inputs)
+
+             decoded_answers = blip_processor.batch_decode(
+                 output_ids,
+                 skip_special_tokens=True,
+             )
+             answer_text = decoded_answers[0] if decoded_answers else ""
+
+             return answer_text or "The model could not produce an answer"
+
+         vqa_pipeline = get_vision_pipeline(model_key)
+
+         vqa_result = vqa_pipeline(
+             image=image_object,
+             question=question_text,
+         )
+
+         top_item = vqa_result[0]
+         answer_text = top_item["answer"]
+         confidence_value = top_item["score"]
+
+         return f"{answer_text} (confidence: {confidence_value:.3f})"
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+
+ def perform_zero_shot_classification(
+     image_object,
+     class_texts: str,
+     clip_key: str,
+ ) -> str:
+     if image_object is None:
+         return "Upload an image"
+
+     try:
+         clip_model, clip_processor = get_clip_components(clip_key)
+
+         class_list = [
+             class_name.strip()
+             for class_name in class_texts.split(",")
+             if class_name.strip()
+         ]
+         if not class_list:
+             return "Specify the classes to classify against"
+
+         input_batch = clip_processor(
+             text=class_list,
+             images=image_object,
+             return_tensors="pt",
+             padding=True,
+         )
+
+         with torch.no_grad():
+             clip_outputs = clip_model(**input_batch)
+             logits_per_image = clip_outputs.logits_per_image
+             probability_tensor = logits_per_image.softmax(dim=1)
+
+         result_lines = ["Results:"]
+         for class_index, class_name in enumerate(class_list):
+             probability_value = probability_tensor[0][class_index].item()
+             result_lines.append(f"{class_name}: {probability_value:.4f}")
+
+         return "\n".join(result_lines)
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+
+ def retrieve_best_image(
+     gallery_value: Any,
+     query_text: str,
+     clip_key: str,
+ ) -> Tuple[str, Image.Image | None]:
+     image_list = _normalize_gallery_images(gallery_value)
+
+     if not image_list or not query_text.strip():
+         return "Upload images and enter a query", None
+
+     try:
+         clip_model, clip_processor = get_clip_components(clip_key)
+
+         image_inputs = clip_processor(
+             images=image_list,
+             return_tensors="pt",
+             padding=True,
+         )
+         with torch.no_grad():
+             image_features = clip_model.get_image_features(**image_inputs)
+             image_features = image_features / image_features.norm(
+                 dim=-1,
+                 keepdim=True,
+             )
+
+         text_inputs = clip_processor(
+             text=[query_text],
+             return_tensors="pt",
+             padding=True,
+         )
+         with torch.no_grad():
+             text_features = clip_model.get_text_features(**text_inputs)
+             text_features = text_features / text_features.norm(
+                 dim=-1,
+                 keepdim=True,
+             )
+
+         # Both feature sets are L2-normalised, so the dot product is the cosine similarity.
+         similarity_tensor = image_features @ text_features.T
+         best_index_tensor = similarity_tensor.argmax()
+         best_index_value = best_index_tensor.item()
+         best_score_value = similarity_tensor[best_index_value].item()
+
+         description_text = (
+             f"Image #{best_index_value + 1} "
+             f"(similarity: {best_score_value:.4f})"
+         )
+         return description_text, image_list[best_index_value]
+     except Exception as e:
+         return f"Error: {str(e)}", None
+
+
+ def segment_image_with_sam_points(
+     image_object,
+     point_coordinates_list: List[List[int]],
+ ) -> Image.Image:
+     if image_object is None:
+         raise ValueError("No image was provided")
+
+     if not point_coordinates_list:
+         return Image.new("L", image_object.size, color=0)
+
+     try:
+         sam_model, sam_processor = get_sam_components()
+
+         # SamProcessor expects points as [batch, num_points, 2] and labels as
+         # [batch, num_points]; label 1 marks a foreground point.
+         batched_points: List[List[List[int]]] = [point_coordinates_list]
+         batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
+
+         sam_inputs = sam_processor(
+             images=image_object,
+             input_points=batched_points,
+             input_labels=batched_labels,
+             return_tensors="pt",
+         )
+
+         with torch.no_grad():
+             sam_outputs = sam_model(**sam_inputs, multimask_output=True)
+
+         # Upscale the predicted low-resolution masks back to the original image size.
+         processed_masks_list = sam_processor.image_processor.post_process_masks(
+             sam_outputs.pred_masks.squeeze(1).cpu(),
+             sam_inputs["original_sizes"].cpu(),
+             sam_inputs["reshaped_input_sizes"].cpu(),
+         )
+
+         batch_masks_tensor = processed_masks_list[0]
+
+         if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
+             return Image.new("L", image_object.size, color=0)
+
+         # Keep the first of the multimask proposals and binarise it.
+         first_mask_tensor = batch_masks_tensor[0]
+         mask_array = first_mask_tensor.numpy()
+
+         binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
+
+         mask_image = Image.fromarray(binary_mask_array, mode="L")
+         return mask_image
+     except Exception as e:
+         print(f"Error: {str(e)}")
+         return Image.new("L", image_object.size, color=0)
+
+
+ def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
+     if image_object is None:
+         return None
+
+     coordinates_text_clean = coordinates_text.strip()
+     if not coordinates_text_clean:
+         return Image.new("L", image_object.size, color=0)
+
+     # Parse "x,y" pairs separated by semicolons or newlines, e.g. "100,200; 300,400".
+     point_coordinates_list: List[List[int]] = []
+
+     for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"):
+         raw_pair_clean = raw_pair.strip()
+         if not raw_pair_clean:
+             continue
+
+         parts = raw_pair_clean.split(",")
+         if len(parts) != 2:
+             continue
+
+         try:
+             x_value = int(parts[0].strip())
+             y_value = int(parts[1].strip())
+         except ValueError:
+             continue
+
+         point_coordinates_list.append([x_value, y_value])
+
+     if not point_coordinates_list:
+         return Image.new("L", image_object.size, color=0)
+
+     return segment_image_with_sam_points(image_object, point_coordinates_list)
+
+
+ def build_interface():
+     with gr.Blocks(title="Vision Processing Demo") as demo:
+         gr.Markdown("# Image Processing System")
+
+         with gr.Tab("Object Detection"):
+             object_input_image = gr.Image(label="Upload an image", type="pil")
+             object_model_selector = gr.Dropdown(
+                 choices=[
+                     "object_detection_conditional_detr",
+                     "object_detection_yolos_small",
+                 ],
+                 label="Model",
+                 value="object_detection_conditional_detr",
+             )
+             object_detect_button = gr.Button("Run detection")
+             object_output_image = gr.Image(label="Result")
+
+             object_detect_button.click(
+                 fn=detect_objects_on_image,
+                 inputs=[object_input_image, object_model_selector],
+                 outputs=object_output_image,
+             )
+
+         with gr.Tab("Segmentation"):
+             segmentation_input_image = gr.Image(label="Upload an image", type="pil")
+             segmentation_button = gr.Button("Run segmentation")
+             segmentation_output_image = gr.Image(label="Mask")
+
+             segmentation_button.click(
+                 fn=segment_image,
+                 inputs=segmentation_input_image,
+                 outputs=segmentation_output_image,
+             )
+
+         with gr.Tab("Depth Estimation"):
+             depth_input_image = gr.Image(label="Upload an image", type="pil")
+             depth_button = gr.Button("Estimate depth")
+             depth_output_image = gr.Image(label="Depth map")
+
+             depth_button.click(
+                 fn=estimate_image_depth,
+                 inputs=depth_input_image,
+                 outputs=depth_output_image,
+             )
+
+         with gr.Tab("Captioning"):
+             caption_input_image = gr.Image(label="Upload an image", type="pil")
+             caption_model_selector = gr.Dropdown(
+                 choices=[
+                     "captioning_blip_base",
+                     "captioning_blip_large",
+                 ],
+                 label="Model",
+                 value="captioning_blip_base",
+             )
+             caption_button = gr.Button("Generate caption")
+             caption_output_text = gr.Textbox(label="Caption", lines=3)
+
+             caption_button.click(
+                 fn=generate_image_caption,
+                 inputs=[caption_input_image, caption_model_selector],
+                 outputs=caption_output_text,
+             )
+
+         with gr.Tab("VQA"):
+             vqa_input_image = gr.Image(label="Upload an image", type="pil")
+             vqa_question_text = gr.Textbox(label="Question", lines=2)
+             vqa_model_selector = gr.Dropdown(
+                 choices=[
+                     "vqa_blip_base",
+                     "vqa_vilt_b32",
+                 ],
+                 label="Model",
+                 value="vqa_blip_base",
+             )
+             vqa_button = gr.Button("Ask question")
+             vqa_output_text = gr.Textbox(label="Answer", lines=3)
+
+             vqa_button.click(
+                 fn=answer_visual_question,
+                 inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
+                 outputs=vqa_output_text,
+             )
+
+         with gr.Tab("Zero-Shot"):
+             zero_shot_input_image = gr.Image(label="Upload an image", type="pil")
+             zero_shot_classes_text = gr.Textbox(
+                 label="Classes",
+                 placeholder="Enter classes separated by commas",
+                 lines=2,
+             )
+             clip_model_selector = gr.Dropdown(
+                 choices=[
+                     "clip_large_patch14",
+                     "clip_base_patch32",
+                 ],
+                 label="Model",
+                 value="clip_large_patch14",
+             )
+             zero_shot_button = gr.Button("Classify")
+             zero_shot_output_text = gr.Textbox(label="Results", lines=8)
+
+             zero_shot_button.click(
+                 fn=perform_zero_shot_classification,
+                 inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
+                 outputs=zero_shot_output_text,
+             )
+
+         with gr.Tab("Search"):
+             retrieval_dir = gr.File(
+                 label="Upload a folder",
+                 file_count="directory",
+                 file_types=["image"],
+                 type="filepath",
+             )
+             retrieval_query_text = gr.Textbox(label="Text query", lines=2)
+             retrieval_clip_selector = gr.Dropdown(
+                 choices=[
+                     "clip_large_patch14",
+                     "clip_base_patch32",
+                 ],
+                 label="Model",
+                 value="clip_large_patch14",
+             )
+             retrieval_button = gr.Button("Find image")
+             retrieval_output_text = gr.Textbox(label="Result")
+             retrieval_output_image = gr.Image(label="Retrieved image")
+
+             retrieval_button.click(
+                 fn=retrieve_best_image,
+                 inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
+                 outputs=[retrieval_output_text, retrieval_output_image],
+             )
+
+         with gr.Tab("SAM"):
+             sam_input_image = gr.Image(label="Upload an image", type="pil")
+             sam_coordinates_text = gr.Textbox(
+                 label="Point coordinates",
+                 placeholder="100,200; 300,400",
+                 lines=3,
+             )
+             sam_button = gr.Button("Segment by points")
+             sam_output_image = gr.Image(label="Mask")
+
+             sam_button.click(
+                 fn=segment_image_with_sam_points_ui,
+                 inputs=[sam_input_image, sam_coordinates_text],
+                 outputs=sam_output_image,
+             )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     interface = build_interface()
+     interface.launch(share=True, server_name="0.0.0.0")
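
The helper functions above can also be exercised without the Gradio UI, which is convenient for a quick smoke test of one model path. A minimal sketch (the image path is a placeholder; the first call downloads the CLIP weights):

    from PIL import Image
    from app import perform_zero_shot_classification

    test_image = Image.open("example.jpg").convert("RGB")  # placeholder test image
    print(perform_zero_shot_classification(test_image, "cat, dog, car", "clip_base_patch32"))
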
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch>=2.1.0
+ torchaudio>=2.1.0
+ numpy>=1.24.0
+ transformers>=4.41.0
+ accelerate>=0.30.0
+ datasets>=2.18.0
+ soundfile>=0.12.1
+ librosa>=0.10.0
+ gradio>=4.0.0
+ gTTS>=2.5.1
+ pydantic==2.10.6
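
A minimal local-run sketch, assuming both files above sit in the working directory and the requirements have been installed (port 7860 is Gradio's default, passed here only to make it explicit):

    from app import build_interface

    demo = build_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)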