Image Similarity Search with Vector Database¶
Use CassIO and Astra DB / Apache Cassandra® for similarity searches between image samples, powered by image embeddings and Vector Search.
NOTE: this uses Cassandra's "Vector Similarity Search" capability. Make sure you are connecting to a vector-enabled database for this demo.
In this notebook you will:
- Download the CIFAR10 dataset of image samples from the torchvision library.
- Calculate image embedding vectors for them with SqueezeNet.
- Store the embedding vectors in a table in your Cassandra / Astra DB instance, using the CassIO library for ease of operation.
- Run one or more searches for images similar to a provided sample.
- Start a simple web app that exposes an image search feature.
Import packages¶
The CassIO object needed for this demo is the VectorTable:
from cassio.vector import VectorTable
Other packages are needed for various tasks in this demo:
import os
from tqdm.auto import tqdm
import numpy as np
from IPython.display import display
# To preprocess image from various formats
from PIL import Image
# To download (image) files
import requests
# To load image dataset
import torch
import torchvision
from torchvision.transforms import (
    Compose,
    Resize,
    CenterCrop,
    ToTensor,
    Normalize,
)
# To build an interactive UI
import gradio
# To plot images
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
try:
    from google.colab import files
    IS_COLAB = True
except ModuleNotFoundError:
    IS_COLAB = False
Connect to your DB¶
A database connection is needed to access Cassandra. The following assumes that a vector-search-capable Astra DB instance is available. Adjust as needed.
from cqlsession import getCQLSession, getCQLKeyspace
cqlMode = "astra_db" # "astra_db"/"local"
session = getCQLSession(mode=cqlMode)
keyspace = getCQLKeyspace(mode=cqlMode)
Load the Data¶
In this demo, you will use image samples from the CIFAR-10 dataset, a labeled collection of 60000 32x32 colour images in 10 classes, with 6000 images per class.
Pre-trained PyTorch vision models expect their input to be preprocessed in a specific way, and results degrade if these steps are skipped. Input images must have three channels in RGB format (a 3 x H x W tensor), with height and width of at least 224 pixels. Pixel values are then normalized using the mean values [0.485, 0.456, 0.406] and standard deviation values [0.229, 0.224, 0.225], which were calculated on the ImageNet dataset. The preprocessing pipeline is defined, and the dataset loaded from the torchvision library, as follows:
preprocess = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
DATA_DIRECTORY = "data"
datasets = {
    "CIFAR10": torchvision.datasets.CIFAR10(
        DATA_DIRECTORY,
        transform=preprocess,
        download=True,
    ),
}
Files already downloaded and verified
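To make the normalization step concrete: ToTensor scales 8-bit pixel values to the range [0, 1], and Normalize then maps each channel value x to (x - mean) / std. The following is a minimal pure-Python sketch of that arithmetic for a single pixel. The helper name normalize_pixel is hypothetical and not part of torchvision.

```python
# ImageNet channel statistics, as used in the preprocessing pipeline above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Normalize one RGB pixel given as 0-255 integers (hypothetical helper)."""
    # Step 1: scale to [0, 1], which is what ToTensor does for 8-bit images
    scaled = [channel / 255.0 for channel in rgb]
    # Step 2: subtract the channel mean and divide by the channel std
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print("white:", [round(v, 3) for v in white])
print("black:", [round(v, 3) for v in black])
```

After normalization the values are no longer confined to [0, 1]: bright pixels become positive and dark pixels negative, which is the input range the pre-trained weights were fitted on.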
Display a few images¶
Each image belongs to a class, identified by a numeric label. These are the labels with their respective names (as found on the CIFAR-10 homepage):
labels_dict = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}
def show_random_images_from_full_dataset(dset, num_rows=2, num_cols=5):
    """Show random sample of images in PyTorch dataset."""
    # Get a random sample of images and labels.
    # randint's upper bound is exclusive, so len(dset) keeps indices in range.
    indices = np.random.randint(0, high=len(dset), size=num_rows*num_cols)
    im_arrays = np.take(dset.data, indices, axis=0)
    labels = [dset.classes[dset.targets[index]] for index in indices]
    # Plot the sample on a grid
    fig = plt.figure(figsize=(8, 8))
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(num_rows, num_cols),
        axes_pad=0.3,
    )
    for ax, im_array, label in zip(grid, im_arrays, labels):
        ax.imshow(im_array)
        ax.set_title(label)
        ax.axis("off")

show_random_images_from_full_dataset(datasets["CIFAR10"])
Select only a few images from each class¶
It generally takes around 1-2 minutes to process 500 images per class. You'll be loading 5000 images (500 per class) if a GPU is available, otherwise 3000 (300 per class), to keep execution time under control.
GPU_AVAILABLE = torch.cuda.device_count() > 0
n_classes = len(labels_dict)
sel_img_map = {i: [] for i in range(n_classes)}
labels_map = {i: [] for i in range(n_classes)}
sel_img_arr_map = {i: [] for i in range(n_classes)}
# Number of images per class: edit the two values below to change it
if GPU_AVAILABLE:
    img_per_class = 500
    print(f"\nGPU is available, hence loading {img_per_class} images per class")
else:
    img_per_class = 300
    print(f"\nCPU is available, hence loading {img_per_class} images per class")
# Iterate through the dataset and select a number of images for each class
for (img, cl_label), img_arr in zip(datasets["CIFAR10"], datasets["CIFAR10"].data):
    # are all classes completed?
    if all(len(ims) >= img_per_class for ims in sel_img_map.values()):
        break
    if len(sel_img_map[cl_label]) < img_per_class:
        # add this one
        sel_img_map[cl_label].append(img)
        labels_map[cl_label].append(cl_label)
        sel_img_arr_map[cl_label].append(img_arr)
# Flatten the per-class lists, class by class, into parallel lists
selected_images = []
labels = []
selected_img_array = []
for cl_label in range(n_classes):
    selected_images += sel_img_map[cl_label]
    labels += labels_map[cl_label]
    selected_img_array += sel_img_arr_map[cl_label]
print("Total selected images:", len(selected_images))
CPU is available, hence loading 300 images per class
Total selected images: 3000
Loading a Pretrained SqueezeNet Computer Vision Model¶
You can load a pre-trained model provided by PyTorch and use its output as an embedding for image data. The SqueezeNet model used here outputs a vector of 1000 floats for each input image.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.squeezenet1_1(
    weights=torchvision.models.SqueezeNet1_1_Weights.IMAGENET1K_V1
).to(device).eval()
if GPU_AVAILABLE:
    print("Loaded the image embedding model on the GPU.")
else:
    print(
        "Loaded the image embedding model on the CPU. Reduced defaults "
        "will be used. Please consider upgrading to a GPU-powered "
        "hardware for best experience."
    )
Loaded the image embedding model on the CPU. Reduced defaults will be used. Please consider upgrading to a GPU-powered hardware for best experience.
Create a DB table through CassIO¶
When an instance of VectorTable is created, CassIO takes care of the underlying database operations. An important parameter to supply is the embedding vector dimension (fixed, in this case, by the choice of the SqueezeNet model being used):
table_name = "image_table"
embedding_dimension = 1000
v_table = VectorTable(
    session=session,
    keyspace=keyspace,
    table=table_name,
    embedding_dimension=embedding_dimension,
    primary_key_type="TEXT",
)
Preparing Vector Embeddings¶
This cell defines the helpers that encode the downloaded images and attach the class of each image as metadata.
def get_vector_metadata(label_indices, class_list):
    """Return list of {"label": <class name>}."""
    return [{"label": class_list[index]} for index in label_indices]

# Creating Vector IDs
# Each vector ID will have a prefix corresponding to CIFAR10
def get_vector_ids(batch_number, batch_size, prefix):
    """Return vector ids."""
    # batch_number is the offset of the batch's first sample (0, 100, 200, ...)
    start_index = batch_number
    end_index = start_index + batch_size
    ids = np.arange(start_index, end_index)
    # create id based on prefix
    # eg. if id == 5, prefix == "CIFAR10", then create "CIFAR10.5" as vector id.
    ids_with_prefix = [f"{prefix}.{x}" for x in ids]
    return ids_with_prefix

def get_vectors_from_batch(data_processed, label_indices, batch_number, dataset):
    """Return list of tuples like (vector_id, vector_values, vector_metadata)."""
    num_records = len(data_processed)
    prefix = dataset.__class__.__name__
    with torch.no_grad():
        # generate image embeddings with PyTorch model
        vector_values = model(data_processed).tolist()
    # return respective IDs/metadata for each image embedding
    vector_metadata = get_vector_metadata(label_indices, dataset.classes)
    vector_ids = get_vector_ids(batch_number, num_records, prefix)
    return list(zip(vector_ids, vector_values, vector_metadata))

dataset = datasets["CIFAR10"]
# Move the data to the respective device
preprocessed_data = torch.stack(selected_images).to(device)
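The vector IDs built above, such as "CIFAR10.5", double as positions in the list of selected images, which is how search results are later mapped back to pictures. The following standalone sketch shows the round trip between an index and its ID. The helper names make_vector_id, parse_vector_id and ids_for_batch are hypothetical and only illustrate the scheme.

```python
def make_vector_id(prefix, index):
    """Build an ID such as 'CIFAR10.5' from a prefix and a sample index."""
    return f"{prefix}.{index}"


def parse_vector_id(vector_id):
    """Split an ID back into (prefix, index); the index is after the last dot."""
    prefix, index = vector_id.rsplit(".", 1)
    return prefix, int(index)


def ids_for_batch(prefix, offset, batch_size):
    """IDs for a batch whose first sample sits at position `offset`."""
    return [make_vector_id(prefix, i) for i in range(offset, offset + batch_size)]


batch_ids = ids_for_batch("CIFAR10", 100, 3)
print(batch_ids)
print(parse_vector_id(batch_ids[-1]))
```

Because the numeric part is the offset of the sample in the overall list, IDs stay unique across batches as long as each batch is given the offset of its first sample.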
Compute and store embedding vectors for images¶
This cell processes the image samples you just loaded. Working in batches, the embedding vectors are computed with the SqueezeNet model, and the results are stored in the Cassandra / Astra DB table by invoking the put_async method of VectorTable.
Note: this operation may take a couple of minutes (depending on whether you are on a CPU or GPU and your hardware specs).
BATCH_SIZE = 100
SAMPLES_TO_PROCESS = len(selected_images)
for i in tqdm(range(0, SAMPLES_TO_PROCESS, BATCH_SIZE)):
    # Find end of batch
    i_end = min(i + BATCH_SIZE, SAMPLES_TO_PROCESS)
    # Generate embeddings for all the images in the batch
    # with the corresponding vector id and metadata lists
    batch_vectors = get_vectors_from_batch(
        preprocessed_data[i:i_end],
        labels[i:i_end],
        i,
        dataset,
    )
    # Insert all entries in the batch concurrently
    futures = []
    for vector_id, embedding, metadata in batch_vectors:
        futures.append(v_table.put_async(
            document=vector_id,
            embedding_vector=embedding,
            document_id=vector_id,
            metadata=metadata,
            ttl_seconds=None,
        ))
    # Wait for every write in the batch to complete before moving on
    for future in futures:
        future.result()
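The loop above follows a "submit everything in the batch, then wait for all of it" pattern. The sketch below reproduces that pattern with Python's standard concurrent.futures module and an in-memory dictionary, so it can run without a database. It is only an analogy: fake_put and insert_in_batches are hypothetical names, and nothing here reflects how CassIO or the Cassandra driver implement put_async.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_put(store, key, value):
    """Stand-in for a database write (hypothetical; not a CassIO call)."""
    store[key] = value
    return key


def insert_in_batches(items, batch_size, store):
    """Submit each batch concurrently, then wait for it before moving on."""
    written = []
    with ThreadPoolExecutor(max_workers=8) as pool:
        for start in range(0, len(items), batch_size):
            batch = items[start:start + batch_size]
            # Fire off all writes of this batch without waiting in between
            futures = [pool.submit(fake_put, store, k, v) for k, v in batch]
            # .result() blocks until the write is done and re-raises any error
            written.extend(future.result() for future in futures)
    return written


store = {}
items = [(f"CIFAR10.{i}", [float(i)]) for i in range(250)]
written = insert_in_batches(items, 100, store)
print(len(written), "entries written in batches of up to 100")
```

Waiting on each batch before starting the next bounds the number of in-flight writes to the batch size, while still overlapping the writes within a batch.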
Run a similarity search¶
You will now obtain a new image file and search for samples similar to it.
Download a sample image of a ship:
url = "https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship1.png"
r = requests.get(url, stream=True)
input_image = Image.open(r.raw)
print("A sample image")
input_image.resize((125, 125))
A sample image
In order to run the search, first get the embedding vector for the input file, then use it to run a similarity search on the CassIO VectorTable:
def display_top_5_matches(matches, im_array_indices, im_label_list):
    # Plot the matching results on a single row
    fig = plt.figure(figsize=(8, 8))
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(1, 5),
        axes_pad=0.3,
    )
    for ax, im_array_index, im_label in zip(grid, im_array_indices, im_label_list):
        ax.imshow(selected_img_array[im_array_index])
        ax.set_title(im_label)
        ax.axis("off")

# Gradients are not needed to compute an embedding
with torch.no_grad():
    query_embedding = model(
        preprocess(input_image).unsqueeze(0).to(device)
    ).tolist()[0]
matches = v_table.search(
    embedding_vector=query_embedding,
    top_k=5,
    metric="cos",
    metric_threshold=None,
)
im_array_indices = []
im_label_list = []
# Show the top 5 matching results
for match_i, match in enumerate(matches):
    print(f"Match {match_i}: {match['document']} ", end="")
    print(f"(label: {match['metadata']['label']}, ", end="")
    print(f"distance: {match['distance']:.4f})")
    # The numeric part of the ID is the index into selected_img_array
    image_id = match["document"].split(".")
    im_array_indices.append(int(image_id[1]))
    im_label_list.append(match["metadata"]["label"])
display_top_5_matches(matches, im_array_indices, im_label_list)
Match 0: CIFAR10.2409 (label: ship, distance: 0.9867)
Match 1: CIFAR10.2653 (label: ship, distance: 0.9851)
Match 2: CIFAR10.2401 (label: ship, distance: 0.9835)
Match 3: CIFAR10.2524 (label: ship, distance: 0.9831)
Match 4: CIFAR10.91 (label: airplane, distance: 0.9830)
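The search above is run with metric="cos", and the printed values decrease from the best match downward, which is consistent with a cosine similarity score where higher means more similar. As an illustration of the measure itself, here is a brute-force sketch in pure Python on toy two-dimensional vectors. The function names are hypothetical, and a vector database is expected to answer such queries with an index rather than the linear scan shown here.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_matches(query, vectors, k):
    """Return the k (key, score) pairs most similar to the query."""
    scored = [(cosine_similarity(query, v), key) for key, v in vectors.items()]
    scored.sort(reverse=True)  # highest similarity first
    return [(key, score) for score, key in scored[:k]]


toy_vectors = {
    "a": [1.0, 0.0],
    "b": [1.0, 1.0],
    "c": [0.0, 1.0],
    "d": [-1.0, 0.0],
}
results = top_k_matches([1.0, 0.1], toy_vectors, k=2)
for key, score in results:
    print(f"{key}: {score:.4f}")
```

Cosine similarity depends only on the direction of the vectors, not on their length, so two embeddings that differ by a scale factor are treated as identical.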
Experiment with your own image file¶
In this section, you can supply any image file of your own to have a bit of fun.
While you're at it, do a bit of refactoring of the image processing steps:
def upload_image(filepath):
    loaded_image = Image.open(filepath)
    loaded_image = loaded_image.resize((125, 125))
    return loaded_image

def image_similarity_search(supplied_image, top_k=5):
    # get the embedding for the image from the model
    with torch.no_grad():
        supplied_image_embedding = model(
            preprocess(supplied_image).unsqueeze(0).to(device)
        ).tolist()[0]
    matches = v_table.search(
        embedding_vector=supplied_image_embedding,
        top_k=top_k,
        metric="cos",
        metric_threshold=None,
    )
    return matches
Now try providing an image file of yours (skip this part if you want):
from PIL import Image
if IS_COLAB:
    print("Please upload an image file from your computer: ")
    uploaded = files.upload()
    img_file_title = list(uploaded.keys())[0]
    img_filepath = os.path.join(os.getcwd(), img_file_title)
else:
    img_filepath = input("Please provide the full path to the image file: ")
supplied_image = upload_image(img_filepath)
upload_array_indices = []
upload_label_list = []
# Convert the image to RGB mode (removing alpha channel if present)
rgb_image = supplied_image.convert("RGB")
print("Your query image:")
display(rgb_image)
matches = image_similarity_search(rgb_image, top_k=3)
# Show the top 3 matching results
for match_i, match in enumerate(matches):
    print(f"Match {match_i}: {match['document']} ", end="")
    print(f"(label: {match['metadata']['label']}, ", end="")
    print(f"distance: {match['distance']:.4f})")
    image_id = match["document"].split(".")
    upload_array_indices.append(int(image_id[1]))
    upload_label_list.append(match["metadata"]["label"])
print("Similar images:")
display_top_5_matches(matches, upload_array_indices, upload_label_list)
Your query image:
Match 0: CIFAR10.67 (label: airplane, distance: 0.9473)
Match 1: CIFAR10.15 (label: airplane, distance: 0.9454)
Match 2: CIFAR10.125 (label: airplane, distance: 0.9443)
Similar images:
Image Similarity Web App¶
The following cells set up and launch a simple application, powered by Gradio, demonstrating the image similarity search seen so far.
In essence, Gradio makes it easy to wrap a graphical interface around the following function, which is built from the components seen earlier: it accepts a user-provided image as input and returns a number of similar images found in the stored collection.
NUM_RESULT_WIDGETS = 5

def gradio_upload_image(input_image):
    found_images = []
    if input_image is not None:
        # Convert the NumPy array to a PIL Image
        pil_image = Image.fromarray(input_image)
        matches = image_similarity_search(pil_image, top_k=NUM_RESULT_WIDGETS)
        # Collect the stored image for each match
        for match in matches:
            image_id = match["document"].split(".")
            found_images.append(selected_img_array[int(image_id[1])])
    # pad the result in any case to the number of displayed widgets
    return found_images + [None] * (NUM_RESULT_WIDGETS-len(found_images))
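The function above always returns one value per output widget, padding with None when the search yields fewer results. The following small sketch isolates that padding step. The helper name pad_results is hypothetical, and unlike the function above it also trims the list if more results than widgets are supplied, which is an added assumption rather than something the original code does.

```python
def pad_results(results, n_widgets, filler=None):
    """Return exactly n_widgets items: trim extras, pad the rest with filler."""
    trimmed = list(results)[:n_widgets]
    return trimmed + [filler] * (n_widgets - len(trimmed))


print(pad_results(["img_a", "img_b"], 5))
print(pad_results([], 3))
```

Keeping the output length fixed means the interface definition and the function stay in agreement regardless of how many matches the search returns.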
The next cell starts the Gradio app: click on the URL that will be displayed to open it.
Please keep in mind that:
- The cell will keep running as long as the UI is running. Interrupt the notebook kernel to regain control (e.g. to modify and re-launch, or execute other cells, etc).
- The cell output will give both a local URL to access the application, and a URL such as https://<....>.gradio.live to reach it from anywhere. Use the latter link from Colab and when sharing with others. (The link will expire after a certain time.)
- The UI will also be shown within the notebook below the cell.
# Define the dimensions for output images
output_width = 300
output_height = 200
image_ui = gradio.Interface(
    fn=gradio_upload_image,
    inputs=gradio.components.Image(label="Your query image"),
    outputs=[
        gradio.components.Image(
            label=f"Search result #{output_i}",
            width=output_width,
            height=output_height,
        )
        for output_i in range(NUM_RESULT_WIDGETS)
    ],
    title="Image Similarity Search with CassIO & Vector Database",
)
image_ui.launch(share=True, debug=True)