🔨 Wrangle unstructured AI data at scale
Datachain enables multimodal API calls and local AI inference to run in parallel over many samples as chained operations. The resulting datasets can be saved, versioned, and sent directly to PyTorch and TensorFlow for training. Datachain can persist the features of Python objects returned by AI models and enables vectorized analytical operations over them.
Typical use cases include data curation, LLM analytics and validation, image segmentation, pose detection, and GenAI alignment. Datachain is especially helpful when batch operations can be optimized: for instance, when synchronous API calls can be parallelized or when an LLM API offers batch processing.
pip install datachain
Datachain is built by composing wrangling operations.
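Here is a minimal sketch of such a composed chain. The bucket path, the derived column, and the dataset name are hypothetical illustrations; the Column filter and the settings(parallel=...) option are assumed to behave as described in the DataChain docs:

from datachain.lib.dc import Column, DataChain

chain = (
    DataChain.from_storage("gs://my-bucket/images", type="image")  # hypothetical bucket
    .filter(Column("file.path").glob("*.jpg"))  # compose: keep JPEG files only
    .settings(parallel=4)  # assumed option: fan the next step out over four processes
    .map(path_len=lambda file: len(file.path), output=int)  # hypothetical derived column
    .save("image-paths")  # hypothetical dataset name
)

Each step returns a new chain, so filtering, settings, mapping, and saving compose freely in any order that makes sense for the data.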
As a fuller example, consider the New Yorker Cartoon Caption Contest dataset, where cartoons are matched with candidate captions. Imagine we want to augment this dataset with synthetic scene descriptions coming from an AI model. The code below takes images from the cloud, applies the PaliGemma model to caption the first five of them, and puts the results in the column “scene”:
#
# pip install transformers
#
from datachain.lib.dc import DataChain
from datachain.lib.file import File
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
images = DataChain.from_storage("gs://datachain-demo/newyorker_caption_contest/images", type="image")
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-mix-224")
processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-224")
def process(file: File) -> str:
    # Read the image from storage and ask PaliGemma to caption it
    image = file.read().convert("RGB")
    inputs = processor(text="caption", images=image, return_tensors="pt")
    generate_ids = model.generate(**inputs, max_new_tokens=100)
    return processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
chain = (
    images.limit(5)  # caption only the first five images
    .settings(cache=True)  # cache downloaded files locally
    .map(scene=process, output=str)  # put the model output in a new "scene" column
    .save()
)
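The final save() persists the result as a dataset version. A chain can also be saved under a name and loaded back later; a small sketch, with a hypothetical dataset name, assuming from_dataset() and show() behave as in the DataChain docs:

chain.save("cartoon-scenes")  # hypothetical name; persists a new dataset version
scenes = DataChain.from_dataset("cartoon-scenes")  # load the latest saved version
scenes.show()  # preview the rows, including the new "scene" column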
Here is how we can view the results in a plot:
import re
from textwrap import wrap

import matplotlib.pyplot as plt

def trim_text(text):
    # Keep only the first sentence of the generated caption
    match = re.search(r'[A-Z][^.]*\.', text)
    return match.group(0) if match else ''

# Materialize the collected values so they can be sized and iterated
images = list(chain.collect_one("file"))
captions = list(chain.collect_one("scene"))

_, axes = plt.subplots(1, len(captions), figsize=(15, 5))
for ax, img, caption in zip(axes, images, captions):
    ax.imshow(img.read(), cmap='gray')
    ax.axis('off')
    wrapped_caption = "\n".join(wrap(trim_text(caption), 30))
    ax.set_title(wrapped_caption, fontsize=6)
plt.show()
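Finally, as noted above, the resulting dataset can be passed to PyTorch for training. A minimal sketch, assuming DataChain's select() and to_pytorch(transform=...) export; the resize dimensions and batch size here are arbitrary choices:

from torch.utils.data import DataLoader
from torchvision import transforms

# Convert images to fixed-size tensors so the default DataLoader collation works
to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

train_ds = chain.select("file", "scene").to_pytorch(transform=to_tensor)
loader = DataLoader(train_ds, batch_size=2)

for images, scenes in loader:
    ...  # feed each batch of image tensors and caption strings to a training loop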