
I tried image matching with RoMa. The RoMa GitHub repository is https://github.com/Parskatt/RoMa.
First written: 2025. 2. 20
It is best to set up a conda environment first, following this post:
Setting up a Python development environment with Visual Studio Code and Miniconda (Windows, Ubuntu, WSL2)
https://webnautes.tistory.com/1842
Now set up an environment for testing RoMa.
conda create -n roma python=3.10
conda activate roma
git clone https://github.com/Parskatt/RoMa.git
cd RoMa
pip install -e .
Once the required Python packages are installed, the romatch package needed to use RoMa is ready. You can then move the romatch folder inside the RoMa repository to wherever you like and use it from there.
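Before wiring it into a GUI, here is a minimal sketch of the core romatch calls that the full code at the bottom of this post relies on. The image file names are placeholders; roma_outdoor, match, and sample are the same calls used in the full code, and sample returns coordinates normalized to [-1, 1].

import torch
from romatch import roma_outdoor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
roma_model = roma_outdoor(device=device)

# Dense warp and per-pixel certainty between the two images
warp, certainty = roma_model.match("toronto_A.jpg", "toronto_B.jpg", device=device)

# Sample sparse correspondences; coordinates come back normalized to [-1, 1]
matches, match_certainty = roma_model.sample(warp, certainty)
print(matches.shape, match_certainty.shape)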
Install one additional package:
pip install pyqt5
When I ran the code at the bottom of this post, the first run took a while because the required model weights are downloaded.
The test used the following two photos:
https://github.com/Parskatt/RoMa/blob/main/assets/toronto_A.jpg
https://github.com/Parskatt/RoMa/blob/main/assets/toronto_B.jpg
It was too slow, presumably because it was running on the CPU, so I installed a CUDA-enabled PyTorch build:
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
The matching results now come back much faster.
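To confirm that the CUDA-enabled build was actually picked up, you can run a quick check (not part of the matching code):

import torch

print(torch.__version__)          # version of the installed PyTorch build
print(torch.cuda.is_available())  # True means matching will run on the GPU
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))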
Here is the result of running it. Click the Load Image 1 and Load Image 2 buttons to select the two images in turn, then click Match Images and wait a moment; the matching result appears as in the screenshot below.
Here is the full code.
import sys
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
from PyQt5.QtWidgets import (QApplication, QMainWindow, QPushButton, QLabel,
                             QVBoxLayout, QHBoxLayout, QWidget, QFileDialog, QProgressBar)
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import Qt, QThread, pyqtSignal
import torch
import numpy as np
import cv2
from romatch import roma_outdoor
import warnings
import time
from functools import lru_cache

warnings.filterwarnings("ignore", category=UserWarning)


class MatcherThread(QThread):
    finished = pyqtSignal(dict)
    progress = pyqtSignal(str)

    def __init__(self, roma_model, image1_path, image2_path, device):
        super().__init__()
        self.roma_model = roma_model
        self.image1_path = image1_path
        self.image2_path = image2_path
        self.device = device

    def run(self):
        try:
            self.progress.emit("Processing images...")
            result = self.process_images()
            self.finished.emit(result)
        except Exception as e:
            self.finished.emit({"error": str(e)})

    def process_images(self):
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Load (and cache) the images
            img1 = self.preprocess_image(self.image1_path)
            img2 = self.preprocess_image(self.image2_path)

            # Perform RoMa matching
            self.progress.emit("Performing ROMA matching...")
            print("Performing ROMA matching...")
            warp, certainty = self.roma_model.match(
                self.image1_path,
                self.image2_path,
                device=self.device
            )

            # Default sampling of sparse correspondences
            matches, match_certainty = self.roma_model.sample(warp, certainty)

            # Move CUDA tensors to the CPU and convert to NumPy
            if torch.is_tensor(matches):
                matches = matches.cpu().numpy()
            if torch.is_tensor(match_certainty):
                match_certainty = match_certainty.cpu().numpy()

            print(f"Initial matches: {len(matches)}")
            print(f"Certainty range: {match_certainty.min():.3f} - {match_certainty.max():.3f}")

            # Relaxed filtering after matching
            if match_certainty is not None and len(match_certainty) > 0:
                # Keep the 500 most confident matches
                top_k = min(500, len(matches))
                top_indices = np.argsort(match_certainty)[-top_k:]
                matches = matches[top_indices]
                match_certainty = match_certainty[top_indices]
                print(f"Selected top {len(matches)} matches")

            # Geometric verification with RANSAC (optional)
            if len(matches) >= 4:
                src_pts = matches[:, :2]
                dst_pts = matches[:, 2:]

                print("\nDebug - Before RANSAC:")
                print(f"Source points shape: {src_pts.shape}")
                print(f"First few source points: \n{src_pts[:5]}")
                print(f"Destination points shape: {dst_pts.shape}")
                print(f"First few destination points: \n{dst_pts[:5]}")

                try:
                    H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 3.0)
                    if mask is not None:
                        inliers = mask.ravel().astype(bool)
                        matches = matches[inliers]
                        match_certainty = match_certainty[inliers]
                        print(f"Matches after RANSAC: {len(matches)}")
                        print("Debug - After RANSAC filtering:")
                        print(f"First few matches: \n{matches[:5]}")
                except Exception as e:
                    print(f"RANSAC failed: {str(e)}")

        except Exception as e:
            import traceback
            error_msg = f"Error in process_images: {str(e)}\n{traceback.format_exc()}"
            print("\nError in process_images:")
            print(error_msg)
            self.progress.emit(f"Error: {str(e)}")
            return {
                "img1": None,
                "img2": None,
                "matches": np.array([]),
                "certainty": np.array([])
            }

        if len(matches) == 0:
            print("No matches found after filtering")
            self.progress.emit("No matches found after filtering")
            return {
                "img1": img1,
                "img2": img2,
                "matches": np.array([]),
                "certainty": np.array([])
            }

        print(f"Final matches: {len(matches)}")
        return {
            "img1": img1,
            "img2": img2,
            "matches": matches,
            "certainty": match_certainty
        }

    @staticmethod
    @lru_cache(maxsize=32)
    def preprocess_image(image_path):
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")
        return img  # return the original image as-is


class ImageMatcher(QMainWindow):
    def __init__(self):
        super().__init__()
        self.roma_model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.image1_path = None
        self.image2_path = None
        self.matcher_thread = None
        self.initUI()

    def initUI(self):
        self.setWindowTitle('Image Matcher')
        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)

        images_layout = QHBoxLayout()

        left_layout = QVBoxLayout()
        self.image1_label = QLabel()
        self.image1_label.setFixedSize(400, 400)
        self.image1_label.setAlignment(Qt.AlignCenter)
        self.image1_label.setStyleSheet("border: 2px solid black")
        self.load_image1_btn = QPushButton('Load Image 1')
        self.load_image1_btn.clicked.connect(self.load_image1)
        left_layout.addWidget(self.image1_label)
        left_layout.addWidget(self.load_image1_btn)

        right_layout = QVBoxLayout()
        self.image2_label = QLabel()
        self.image2_label.setFixedSize(400, 400)
        self.image2_label.setAlignment(Qt.AlignCenter)
        self.image2_label.setStyleSheet("border: 2px solid black")
        self.load_image2_btn = QPushButton('Load Image 2')
        self.load_image2_btn.clicked.connect(self.load_image2)
        right_layout.addWidget(self.image2_label)
        right_layout.addWidget(self.load_image2_btn)

        images_layout.addLayout(left_layout)
        images_layout.addLayout(right_layout)

        self.result_label = QLabel()
        self.result_label.setFixedSize(800, 400)
        self.result_label.setAlignment(Qt.AlignCenter)
        self.result_label.setStyleSheet("border: 2px solid black")

        self.progress_bar = QProgressBar()
        self.progress_bar.setTextVisible(True)
        self.progress_bar.hide()

        self.status_label = QLabel()
        self.status_label.setAlignment(Qt.AlignCenter)

        self.match_btn = QPushButton('Match Images')
        self.match_btn.clicked.connect(self.initialize_and_match)
        self.match_btn.setEnabled(False)

        layout.addLayout(images_layout)
        layout.addWidget(self.result_label)
        layout.addWidget(self.progress_bar)
        layout.addWidget(self.status_label)
        layout.addWidget(self.match_btn)

        self.setGeometry(100, 100, 850, 900)
        self.show()

    def initialize_model(self):
        if self.roma_model is None:
            try:
                status_msg = f"Initializing model on {self.device}..."
                print(status_msg)
                self.status_label.setText(status_msg)
                QApplication.processEvents()

                # Set the resolutions to multiples of 14
                self.roma_model = roma_outdoor(
                    device=self.device,
                    coarse_res=322,            # 14 * 23 = 322
                    upsample_res=(644, 644)    # 14 * 46 = 644
                ).to(self.device)

                success_msg = "Model initialized successfully"
                print(success_msg)
                self.status_label.setText(success_msg)
                return True
            except Exception as e:
                import traceback
                error_msg = f"Error initializing model: {str(e)}\n{traceback.format_exc()}"
                print("\nError in initialize_model:")
                print(error_msg)
                self.status_label.setText(f"Error initializing model: {str(e)}")
                return False
        return True

    def initialize_and_match(self):
        if not self.initialize_model():
            return
        self.match_btn.setEnabled(False)
        self.progress_bar.setMaximum(0)
        self.progress_bar.show()
        self.matcher_thread = MatcherThread(
            self.roma_model,
            self.image1_path,
            self.image2_path,
            self.device
        )
        self.matcher_thread.finished.connect(self.handle_matching_result)
        self.matcher_thread.progress.connect(self.update_progress)
        self.matcher_thread.start()

    def handle_matching_result(self, result):
        if "error" in result:
            error_msg = f"Error: {result['error']}"
            print(error_msg)
            self.status_label.setText(error_msg)
        else:
            self.visualize_matches(result)
        self.progress_bar.hide()
        self.match_btn.setEnabled(True)
        self.matcher_thread = None

    def update_progress(self, message):
        print(message)
        self.status_label.setText(message)

    def visualize_matches(self, result):
        try:
            img1 = result["img1"]
            img2 = result["img2"]
            matches = result["matches"]
            match_certainty = result["certainty"]

            # Guard: draw_matches cannot index an empty match array
            if len(matches) == 0:
                self.status_label.setText("No reliable matches found. Try adjusting the matching parameters or using different images.")
                return

            # Visualize the matching result on the original images
            matched_img = self.draw_matches(
                img1, img2,
                matches,
                matches[:, :2],   # original coordinates
                matches[:, 2:],   # original coordinates
                match_certainty
            )

            # Shrink the result image to fit the UI
            display_height = self.result_label.height()
            display_width = self.result_label.width()

            # Resize while keeping the aspect ratio
            img_height, img_width = matched_img.shape[:2]
            aspect_ratio = img_width / img_height
            if img_width / display_width > img_height / display_height:
                new_width = display_width
                new_height = int(display_width / aspect_ratio)
            else:
                new_height = display_height
                new_width = int(display_height * aspect_ratio)

            matched_img_resized = cv2.resize(matched_img, (new_width, new_height),
                                             interpolation=cv2.INTER_AREA)

            # Display the result
            matched_img_rgb = cv2.cvtColor(matched_img_resized, cv2.COLOR_BGR2RGB)
            height, width = matched_img_rgb.shape[:2]
            bytes_per_line = 3 * width
            q_img = QImage(matched_img_rgb.tobytes(), width, height,
                           bytes_per_line, QImage.Format_RGB888)
            pixmap = QPixmap.fromImage(q_img)
            self.result_label.setPixmap(pixmap)

            match_count = len(matches)
            if match_count > 0:
                avg_certainty = float(np.nanmean(match_certainty)) if match_certainty is not None else 0
                status_msg = f"Matching completed - Found {match_count} matches (Avg certainty: {avg_certainty:.2f})"
            else:
                status_msg = "No reliable matches found. Try adjusting the matching parameters or using different images."
            print(status_msg)
            self.status_label.setText(status_msg)
        except Exception as e:
            import traceback
            error_msg = f"Error in visualization: {str(e)}\n{traceback.format_exc()}"
            print("\nError in visualization:")
            print(error_msg)
            self.status_label.setText(f"Error in visualization: {str(e)}")

    def draw_matches(self, img1, img2, matches, kpts1, kpts2, match_certainty=None):
        # Create the output image
        new_width = img1.shape[1] + img2.shape[1]
        new_height = max(img1.shape[0], img2.shape[0])
        out = np.zeros((new_height, new_width, 3), dtype=np.uint8)

        # Place the two images side by side
        out[:img1.shape[0], :img1.shape[1]] = img1
        out[:img2.shape[0], img1.shape[1]:] = img2

        # Pick a color for each match based on its certainty
        if match_certainty is not None:
            colors = []
            for cert in match_certainty:
                if cert < 0.7:
                    color = (0, 0, 255)   # red (BGR)
                elif cert < 0.85:
                    color = (0, 255, 0)   # green
                else:
                    color = (255, 0, 0)   # blue
                colors.append(color)
        else:
            colors = [(0, 255, 0)] * len(kpts1)  # all green

        print(f"Drawing {len(matches)} match lines...")

        # Draw the match lines and points
        for i in range(len(matches)):
            # Convert normalized coordinates to actual image coordinates
            x1 = int((matches[i][0] + 1) * img1.shape[1] / 2)
            y1 = int((matches[i][1] + 1) * img1.shape[0] / 2)
            x2 = int((matches[i][2] + 1) * img2.shape[1] / 2)
            y2 = int((matches[i][3] + 1) * img2.shape[0] / 2)

            # Coordinate debugging
            print(f"\nMatch {i}:")
            print(f"Original normalized - P1:({matches[i][0]:.2f}, {matches[i][1]:.2f}) P2:({matches[i][2]:.2f}, {matches[i][3]:.2f})")
            print(f"Image coordinates - P1:({x1}, {y1}) P2:({x2}, {y2})")

            # Shift the x coordinate into the second image
            x2 += img1.shape[1]

            # Draw the match line (thickness increased to 2)
            cv2.line(out, (x1, y1), (x2, y2), colors[i], 2, cv2.LINE_AA)

            # Draw the match points
            cv2.circle(out, (x1, y1), 4, colors[i], -1, cv2.LINE_AA)
            cv2.circle(out, (x2, y2), 4, colors[i], -1, cv2.LINE_AA)

        print("Match visualization completed")
        return out

    def load_image1(self):
        self.image1_path = self.load_image(self.image1_label)
        if self.image1_path:
            print(f"Loaded image 1: {self.image1_path}")
        self.check_enable_match()

    def load_image2(self):
        self.image2_path = self.load_image(self.image2_label)
        if self.image2_path:
            print(f"Loaded image 2: {self.image2_path}")
        self.check_enable_match()

    def load_image(self, label):
        try:
            file_name, _ = QFileDialog.getOpenFileName(
                self, "Open Image File", "",
                "Images (*.png *.xpm *.jpg *.bmp)"
            )
            if file_name and os.path.exists(file_name):
                pixmap = QPixmap(file_name)
                if not pixmap.isNull():
                    scaled_pixmap = pixmap.scaled(label.size(),
                                                  Qt.KeepAspectRatio,
                                                  Qt.SmoothTransformation)
                    label.setPixmap(scaled_pixmap)
                    return file_name
            return None
        except Exception as e:
            error_msg = f"Error loading image: {str(e)}"
            print(error_msg)
            return None

    def check_enable_match(self):
        should_enable = bool(self.image1_path) and bool(self.image2_path)
        self.match_btn.setEnabled(should_enable)
        if should_enable:
            print("Match button enabled - Ready to process images")
        else:
            print("Match button disabled - Please load both images")


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = ImageMatcher()
    sys.exit(app.exec_())
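To run it, save the code to a file (the file name below is arbitrary) in a directory where the romatch package is importable, then launch it with Python:

python roma_matcher_gui.py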