import os
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation
# Batch background removal with briaai/RMBG-2.0.
# Walks every file under `root_dir` and predicts a foreground mask for each
# image. The RGB image plus that mask (as alpha) is saved as a PNG under
# `output_dir`, mirroring the input folder structure.

# RMBG-2.0 preprocessing (per the model card): resize to 1024x1024, convert to
# a [0, 1] tensor, then normalise with the ImageNet mean/std.
# Previously `transform_image` was used without ever being defined -> NameError.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Fall back to CPU so the script still runs (slowly) without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
# 'high' allows faster, slightly lower-precision float32 matmul kernels.
torch.set_float32_matmul_precision('high')
model.to(device)
model.eval()

root_dir = r'C:\Users\admin\Downloads\raw'
# NOTE(review): the original script computed the cut-outs but never saved them.
# They are now written to a sibling "<root_dir>_nobg" folder. Keeping outputs
# outside root_dir means a re-run does not re-process its own results.
# Change this if a different destination is wanted.
output_dir = Path(root_dir + '_nobg')

# Every regular file below root_dir, sorted so runs are reproducible.
# Non-image files are filtered out later, when opening fails.
data = sorted(p for p in Path(root_dir).rglob('*') if p.is_file())

for file_path in tqdm(data):
    try:
        # The context manager closes the file handle; convert() has already
        # loaded the pixels into a new, independent image.
        with Image.open(file_path) as src:
            image = src.convert('RGB')
    except OSError as exc:  # includes PIL.UnidentifiedImageError (non-image files)
        # One unreadable file should not abort the whole batch.
        tqdm.write(f'Skipping {file_path}: {exc}')
        continue

    input_images = transform_image(image).unsqueeze(0).to(device)  # add batch dim

    # Prediction: the last model output is taken as the mask logits.
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # The mask is predicted at 1024x1024; scale it back to the source size.
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)

    # Save as PNG: the result is RGBA, and formats like JPEG cannot store alpha.
    out_path = (output_dir / file_path.relative_to(root_dir)).with_suffix('.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(out_path)