Domain-Agnostic Foundation Models
Fine-Tuning the Segment Anything Model (SAM)
The Segment Anything Model (SAM) by Meta is a transformer-based foundation model trained on a vast amount of data (around 1 billion masks), which has given it a general concept of what an "object" is. It can therefore segment objects in new images with little or no further training. However:
- The model has no notion of classes and only returns "masks".
- Despite its broad knowledge, SAM has seen few domain-specific images and therefore performs poorly on specialised tasks such as medical image segmentation.
We have developed a proprietary variant of SAM which can be fine-tuned for your domain. In this tutorial, we will demonstrate how to fine-tune it using the Synativ SDK.
When should you fine-tune SAM?
☑️ You are working on a semantic segmentation task.
☑️ You have a large amount of data available to fine-tune SAM to your domain. If you only have a limited amount of data, it is generally better to start with one of Synativ's domain-specific foundation models.
Setting up Synativ
Make sure that you have installed the Synativ SDK before you authenticate with your API key:
from synativ.api import Synativ
synativ_api: Synativ = Synativ(api_key="{YOUR_API_KEY}")
Preparing your data
Dataset format
Before uploading your data to our cloud, make sure it is in the correct format. The folder structure should look like this:
data
├── train
│   ├── ground_truth
│   │   ├── 000.png
│   │   ├── 001.png
│   │   └── xxx.png
│   └── input
│       ├── 000.png
│       ├── 001.png
│       └── xxx.png
├── test
│   ├── ground_truth
│   │   ├── 000.png
│   │   ├── 001.png
│   │   └── xxx.png
│   └── input
│       ├── 000.png
│       ├── 001.png
│       └── xxx.png
└── val  # NOTE: optional
    ├── ground_truth
    └── input
Image names can differ (.jpg, .jpeg, and .png files are accepted), but every input image needs a corresponding ground_truth label with the same filename. There is no limit on the number of images.
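Before uploading, it can help to check locally that every input image has a matching label. The sketch below is a sanity check written for this tutorial, not part of the Synativ SDK; the helper names are hypothetical, and it assumes the folder layout shown above, with each label sharing the exact filename (including extension) of its input image.

```python
from pathlib import Path

# Extensions accepted by Synativ, per the dataset format described above.
ACCEPTED_SUFFIXES = {".jpg", ".jpeg", ".png"}


def image_names(folder: Path) -> set:
    """Return the filenames of accepted image files directly inside `folder`."""
    return {
        p.name
        for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in ACCEPTED_SUFFIXES
    }


def find_unpaired_images(split_dir) -> list:
    """List input images in `split_dir` (e.g. data/train) that have no
    ground_truth file with the same filename. Hypothetical helper, not SDK code."""
    split = Path(split_dir)
    inputs = image_names(split / "input")
    labels = image_names(split / "ground_truth")
    return sorted(inputs - labels)
```

You could run it on data/train, data/test and, if present, data/val; an empty list means every input image has a label.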
The labels should be one-channel (greyscale) images in which each pixel value is the class index. For binary segmentation, the labels should therefore contain only 0s and 1s, with 0 being the background. For multiclass segmentation, the pixel values should correspond to the class indices {0, 1, 2, ...}.
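A common pitfall is binary masks exported with 255 (white) as the foreground instead of 1. The sketch below assumes NumPy and uses hypothetical helper names: it converts such masks to 0/1 class indices and checks that a label only contains values in {0, ..., n_classes - 1}.

```python
import numpy as np


def binarize_label(mask) -> np.ndarray:
    """Convert a binary mask stored as 0/255 (or any 0/non-zero values)
    into the 0/1 class indices expected for binary segmentation."""
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"expected a one-channel (H, W) label, got shape {mask.shape}")
    return (mask > 0).astype(np.uint8)


def check_label(mask, n_classes: int) -> None:
    """Raise ValueError if a label contains values outside {0, ..., n_classes - 1}."""
    values = np.unique(np.asarray(mask))
    invalid = values[(values < 0) | (values >= n_classes)]
    if invalid.size:
        raise ValueError(f"label contains invalid class values: {invalid.tolist()}")
```

To apply this to files, you could load a mask with Pillow via np.array(Image.open(path).convert("L")) and write the result back with Image.fromarray(binarize_label(mask)).save(path).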
We will soon release a version of SAM that can be fine-tuned on unlabelled data. Let us know if you would like to become one of our beta users!
Uploading your data
To use proprietary data, you need to create a Synativ Dataset and give it a friendly name. It will automatically zip your data folder and upload it upon creation.
from synativ import Dataset
dataset: Dataset = synativ_api.create_dataset(
dataset_name="your_dataset",
dataset_dir="<path_to_your_dataset>"
)
This will return a Dataset with a few details, most importantly a DatasetId that looks like this: synativ-dataset-yyyyyyyy-yyyy-yyyy-yyyy-yyyyyyyyyy. More info on Synativ Datasets can be found here.
Fine-tuning your model
In this tutorial, we use SAM-B(ase) as the foundation model, but Synativ also supports SAM-H(uge) for more complex tasks. The model comes with sensible default hyperparameters, but you can pass your own if needed (see below).
Starting your fine-tuning job
The arguments you pass to the fine-tuning function are:
- base_model: which foundation model to use. Here, we select base_model="sam_semantic_base" ("sam_semantic_huge" is also possible).
- dataset_id: the DatasetId of the dataset you uploaded, i.e. dataset.id.
- metadata: a dictionary of hyperparameters you can modify. Their default values are:
metadata = { "num_epochs": 50, "learning_rate": 8e-4, "n_classes": 2 }
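Make sure n_classes matches your task: it should equal the number of distinct class indices in your labels, including the background class 0 (the default of 2 corresponds to binary segmentation).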
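If your classes are numbered contiguously from 0, n_classes is the largest label value plus one. The sketch below (NumPy, hypothetical helper names) derives it from your label arrays and builds a metadata dictionary containing only the keys you want to override; it assumes, as described above, that any hyperparameter you leave out keeps its Synativ default.

```python
import numpy as np


def infer_n_classes(labels) -> int:
    """Number of classes implied by label arrays whose pixel values are the
    contiguous class indices {0, 1, ..., K-1} (background = 0)."""
    return max(int(np.max(label)) for label in labels) + 1


def build_metadata(labels, **overrides) -> dict:
    """Metadata for fine_tune with n_classes derived from the labels.
    Extra keyword arguments (e.g. num_epochs=100) are passed through as-is."""
    metadata = {"n_classes": infer_n_classes(labels)}
    metadata.update(overrides)
    return metadata
```

For example, for labels containing the values 0, 1 and 2, build_metadata(labels, num_epochs=100) returns {"n_classes": 3, "num_epochs": 100}, which you could pass as metadata= to fine_tune.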
You can start fine-tuning by calling fine_tune:
from synativ import Model
model: Model = synativ_api.fine_tune(
base_model="sam_semantic_base",
dataset_id=dataset.id,
metadata={}
)
This will initiate a fine-tuning job in our backend. Note that metadata is a dictionary through which you can set hyperparameters for this particular job. If left empty, the Synativ default parameters are used.
You will receive a Model object as response:
Model(
creation_time='2023-08-07 13:16:02.992559',
checkpoint='',
metadata='{<used_parameters>}',
base_model='sam_semantic_base',
dataset_id='synativ-dataset-yyyyyyyy-yyyy-yyyy-yyyy-yyyyyyyyyy',
id='synativ-model-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx',
)
The SDK will always return the full list of configurable hyperparameters used in metadata, even if they were not overwritten by the user.
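To confirm that your overrides were applied, you can compare them against the metadata string returned on the Model. The sketch below assumes that this string is JSON-encoded (the '{<used_parameters>}' placeholder suggests a JSON object, but the tutorial does not state it), and unapplied_overrides is a hypothetical helper name.

```python
import json


def unapplied_overrides(requested: dict, used_metadata: str) -> dict:
    """Return the requested hyperparameters whose value differs from the one
    in the metadata string the SDK reports as used (assumed to be JSON)."""
    used = json.loads(used_metadata)
    return {key: value for key, value in requested.items() if used.get(key) != value}
```

For example, unapplied_overrides({"n_classes": 3}, model.metadata) should return an empty dict if n_classes=3 was actually used.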
Monitoring your fine-tuning job
You can check the status of your fine-tuning job by calling get_model_status with the respective ModelId:
synativ_api.get_model_status(model_id=model.id)
This will return a Status
object with one of the following:
Status(status='NOT_FOUND') ## Wrong ID
Status(status='QUEUED') ## Job is queued
Status(status='SETTING_UP') ## Job is setting up
Status(status='DOWNLOADING_DATA') ## Downloading data and fine-tuned model
Status(status='RUNNING_INFERENCE') ## Inference in progress
Status(status='SAVING_RESULTS') ## Saving inference results
Status(status='COMPLETED') ## Inference has completed
Status(status='FAILED') ## Inference has failed
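Rather than checking the status by hand, you can poll it until the job reaches a terminal state. The sketch below is generic: it accepts any zero-argument callable returning an object with a .status attribute, like the Status objects shown above. Treating COMPLETED, FAILED and NOT_FOUND as terminal, and the default polling interval, are our assumptions; wait_until_done is a hypothetical helper.

```python
import time

# Statuses after which polling stops (our assumption based on the list above).
TERMINAL_STATUSES = {"COMPLETED", "FAILED", "NOT_FOUND"}


def wait_until_done(get_status, poll_seconds=60.0, timeout_seconds=None):
    """Call get_status() repeatedly until it reports a terminal status and
    return that status string; raise TimeoutError if timeout_seconds elapses."""
    start = time.monotonic()
    while True:
        status = get_status().status
        if status in TERMINAL_STATUSES:
            return status
        if timeout_seconds is not None and time.monotonic() - start > timeout_seconds:
            raise TimeoutError(f"job still in status {status!r} after {timeout_seconds} s")
        time.sleep(poll_seconds)
```

For a fine-tuning job this could be wait_until_done(lambda: synativ_api.get_model_status(model_id=model.id)), assuming get_model_status accepts a model_id keyword in the same way download_model does.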
As a rule of thumb, with sam_semantic_base and the default "num_epochs": 50, fine-tuning takes around half an hour per hundred images.
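For planning, the rule of thumb can be turned into a quick estimate. The sketch below assumes that fine-tuning time grows linearly with both the number of images and num_epochs; the epoch scaling is our assumption rather than a Synativ figure, and estimate_fine_tuning_hours is a hypothetical helper.

```python
def estimate_fine_tuning_hours(n_images: int, num_epochs: int = 50) -> float:
    """Rough wall-clock estimate for sam_semantic_base: ~0.5 h per 100 images
    at 50 epochs, scaled linearly in num_epochs (an assumption)."""
    return 0.5 * (n_images / 100) * (num_epochs / 50)
```

For example, estimate_fine_tuning_hours(1000) gives 5.0, i.e. roughly five hours for 1,000 images at the default 50 epochs.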
Evaluating your fine-tuned model
Once the model is fine-tuned, you can evaluate the results by running inference on the test set that you uploaded earlier.
You can start inference by calling start_inference:
from synativ import Inference
inference: Inference = synativ_api.start_inference(
model_id=model.id,
dataset_id=dataset.id,
metadata={}
)
This will initiate an inference job in our backend. Note that metadata is a dictionary through which you can set hyperparameters for this particular job. If left empty, the Synativ default parameters are used.
You will receive an Inference object as response:
Inference(
creation_time='2023-08-07 13:16:02.992559',
metadata='{<used_parameters>}',
model_id='synativ-model-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx',
dataset_id='synativ-dataset-yyyyyyyy-yyyy-yyyy-yyyy-yyyyyyyyyy',
id='synativ-inference-zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzz'
)
The SDK will always return the full list of configurable hyperparameters used in metadata, even if they were not overwritten by the user.
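Inference jobs are generally much faster than fine-tuning, but you can monitor them in the same way by calling get_inference_status with the respective InferenceId:
synativ_api.get_inference_status(inference_id=inference.id)
More info can be found here.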
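Once you have predictions for the test set, a standard way to score semantic segmentation is intersection over union (IoU) per class. The tutorial does not describe the format of inference results, so the sketch below assumes you have already loaded each prediction and its ground_truth label as NumPy arrays of class indices with the same shape; per_class_iou is a hypothetical helper.

```python
import numpy as np


def per_class_iou(pred, target, n_classes: int) -> np.ndarray:
    """Intersection over union for each class index in {0, ..., n_classes - 1}.

    Classes absent from both pred and target get NaN, so they can be skipped
    with np.nanmean when averaging.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    ious = np.full(n_classes, np.nan)
    for c in range(n_classes):
        p = pred == c
        t = target == c
        union = np.logical_or(p, t).sum()
        if union > 0:
            ious[c] = np.logical_and(p, t).sum() / union
    return ious
```

Averaging with np.nanmean(per_class_iou(pred, target, n_classes)) gives the mean IoU for one image, ignoring classes that appear in neither mask.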
Using your fine-tuned SAM
You can use your fine-tuned SAM in two ways:
- Use it as a domain-specific starting point and fine-tune it further for a particular application. You can continue fine-tuning it with the Synativ SDK by referencing its ModelId.
- Use Synativ to host the model for inference.
You can download the model to your local disk by calling download_model with the respective ModelId:
synativ_api.download_model(
model_id=model.id,
local_dir='<path_where_you_want_to_save_your_model>'
)
This may take a while, as the model is usually a few GB in size.
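Since the download is several gigabytes, it may be worth checking free disk space first. This sketch uses Python's shutil.disk_usage; the 5 GiB default threshold is an arbitrary assumption based on "a few GB", not a documented model size, and has_free_space is a hypothetical helper.

```python
import shutil


def has_free_space(directory, required_gb: float = 5.0) -> bool:
    """Return True if the filesystem containing `directory` (which must already
    exist) has at least `required_gb` GiB of free space."""
    free_bytes = shutil.disk_usage(directory).free
    return free_bytes >= required_gb * 1024**3
```

You could call has_free_space('<path_where_you_want_to_save_your_model>') before download_model and pick another directory if it returns False.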