Archives dans août 2024

Comment utiliser Segment Anything 2 (SAM2) pour la segmentation d’images ?

Comment utiliser Segment Anything 2 (SAM2) pour la segmentation d’images ?

Segment Anything Model (SAM) est un modèle de segmentation d’images développé par Meta (anciennement Facebook). Pour utiliser SAM2, vous devez suivre plusieurs étapes, y compris l’installation des bibliothèques nécessaires, le chargement du modèle, et l’exécution de la segmentation sur vos images. Voici un guide général pour vous aider à démarrer :

Étape 1 : Installation des dépendances

  1. Installer Python et les bibliothèques nécessaires :
    • Assurez-vous d’avoir Python installé sur votre machine.
    • Installez les bibliothèques nécessaires, telles que PyTorch, OpenCV, et d’autres dépendances spécifiques à SAM2.
      pip install torch torchvision torchaudio
      pip install opencv-python
      pip install numpy
      

      Cloner le dépôt SAM2 :

      • Clonez le dépôt GitHub de SAM2 pour obtenir le code source et les modèles pré-entraînés.
      git clone https://github.com/facebookresearch/segment-anything.git
      cd segment-anything
      

      Installer les dépendances spécifiques à SAM2 :

    • pip install -e . -q

      Étape 2 : Charger le modèle

      1. Importer les bibliothèques nécessaires :
      2. import cv2
        import torch
        import base64
        
        import numpy as np
        from PIL import Image
        
        import matplotlib.pyplot as plt
        
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

        Charger le modèle SAM2 :

      3. torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
        
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        CHECKPOINT = f"{HOME}/checkpoints/sam2_hiera_small.pt"
        CONFIG = "sam2_hiera_s.yaml"
        
        sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)

        Étape 3 : Préparer l’image

        1. Charger et préparer l’image :
          image_bgr = cv2.imread("/content/de1.PNG")
          image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
          
          sam2_result = mask_generator.generate(image_rgb)

          Étape 4 : Exécuter la segmentation

        2. mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
          sam2_result = mask_generator.generate(image_rgb)

          Étape 5 : Visualiser les résultats

          1. Afficher les masques de segmentation :
          2. # Process results to extract masks
            masks = [mask["segmentation"] for mask in sam2_result]
            
            # Create an overlay of the masks on the original image
            overlay_image = image_rgb.copy()
            
            # Assign colors to each mask
            colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]  # Add more colors if needed
            for i, mask in enumerate(masks):
                color = colors[i % len(colors)]  # Loop through colors if there are more masks than colors
                colored_mask = np.zeros_like(image_rgb, dtype=np.uint8)
                colored_mask[mask] = color
                overlay_image = cv2.addWeighted(overlay_image, 1, colored_mask, 0.5, 0)
            
            # Plot images using Matplotlib
            plt.figure(figsize=(10, 5))
            
            # Display original image
            plt.subplot(1, 2, 1)
            plt.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
            plt.title('Source Image')
            plt.axis('off')
            
            # Display overlay image
            plt.subplot(1, 2, 2)
            plt.imshow(overlay_image)
            plt.title('Segmented Image')
            plt.axis('off')
            
            plt.tight_layout()
            plt.show()

            Conclusion

            En suivant ces étapes, vous pouvez utiliser SAM2 pour segmenter des images. Assurez-vous de consulter la documentation officielle et les exemples fournis dans le dépôt GitHub pour des instructions plus détaillées et des options avancées.


1 2