wglint / 4_sdxl

Stable Diffusion XL - Refiner

  • Public
  • 108 runs
  • GitHub

What this model does and how it works

What this model does

This model, named 4_SDXL, uses the Stable Diffusion XL 1.0 model to generate pictures. You can optionally use a refiner model to improve the picture quality.

How this model works

Before starting, you need to have Cog and Docker installed. To learn about Cog, click here for the GitHub docs. To get started, use brew to install Cog:

brew install cog

For this model, I use only 2 files:

All the code is in this GitHub repo.

Or, let's look at all the code here:

cog.yaml

# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true
  cuda: "11.8"
  python_version: "3.9"
  system_packages:
    - "libgl1-mesa-glx"
    - "ffmpeg"
    - "libsm6"
    - "libxext6"
    - "wget"
  # Each Python package is pinned once.
  python_packages:
    - "diffusers==0.19.3"
    - "torch==2.0.1"
    - "transformers==4.31.0"
    - "invisible-watermark==0.2.0"
    - "accelerate==0.21.0"
    - "pandas==2.0.3"
    - "torchvision==0.15.2"
    - "numpy==1.25.1"
    - "fire==0.5.0"
    - "opencv-python>=4.1.0.25"
    - "mediapipe==0.10.2"

  run:
    # Install the pget binary into /usr/local/bin and make it executable.
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
    # Download the MediaPipe face landmarker model. wget must be given exactly
    # one URL here: with -O, every URL on the command line is written into the
    # same output file.
    - wget -O face_landmarker_v2_with_blendshapes.task -q https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task

# Entry point (file:class) that Cog uses to run predictions.
predict: "predict.py:Predictor"
image: "r8.im/wglint/4_sdxl"

predict.py

# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

from cog import BasePredictor, Input, Path

from typing import List

from diffusers import (
    DDIMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
)
import torch

# Model identifiers passed to `DiffusionPipeline.from_pretrained` when the
# predictor is set up: the SDXL 1.0 base model and its refiner.
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# NOTE(review): "REFINDER" looks like a typo for "REFINER"; the name is left
# unchanged because other code refers to it.
REFINDER_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"

class KarrasDPM:
    """Factory for a DPMSolverMultistep scheduler with Karras sigmas enabled.

    Offers a ``from_config`` entry point so it can be used wherever a
    scheduler class's ``from_config`` is expected.
    """

    @staticmethod
    def from_config(config):
        """Build a ``DPMSolverMultistepScheduler`` from *config*.

        Args:
            config: Scheduler configuration forwarded unchanged to
                ``DPMSolverMultistepScheduler.from_config``.

        Returns:
            The scheduler created with ``use_karras_sigmas=True``.
        """
        # Declared static so the call works on both the class and an instance;
        # without it, an instance call would bind the instance as `config`.
        return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)

# Maps a scheduler name to the class whose `from_config` builds that
# scheduler. "KarrasDPM" points at the local wrapper class rather than a
# diffusers scheduler class.
SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KarrasDPM": KarrasDPM,
    "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
    "K_EULER": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}

class Predictor(BasePredictor):
    """Cog predictor that generates images with SDXL, optionally using the refiner."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

        self.sdxl = DiffusionPipeline.from_pretrained(
            SDXL_MODEL,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

        # The refiner is given the base pipeline's VAE and second text encoder
        # instead of loading its own.
        self.refiner = DiffusionPipeline.from_pretrained(
            REFINDER_MODEL,
            vae=self.sdxl.vae,
            text_encoder_2=self.sdxl.text_encoder_2,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

    def predict(
        self,
        prompt: str = Input(description="Prompt to generate from", default="A studio photo of a rainbow coloured cat"),
        negative_prompt: str = Input(description="Negative prompt: what to avoid in the generated image", default=""),
        scheduler: str = Input(description="Scheduler to use", default="DDIM", choices=[
            "DDIM",
            "DPMSolverMultistep",
            "HeunDiscrete",
            "KarrasDPM",
            "K_EULER_ANCESTRAL",
            "K_EULER",
            "PNDM"
        ]),
        width: int = Input(description="Width of the generated image", default=1024),
        height: int = Input(description="Height of the generated image", default=1024),
        guidance_scale: float = Input(description="Guidance scale", default=7.5),
        num_inteference_steps: int = Input(description="Number of inference steps", default=50),
        number_picture: int = Input(description="Number of picture to generate", default=1, ge=1, le=5),
        seed: int = Input(description="Random seed", default=1334),
        Refiner: bool = Input(description="Refine the generated image", default=False),
        Refiner_noise : float = Input(description="Refiner noise", default=0.8, ge=0.0, le=1.0),
    ) -> List[Path]:
        """Generate ``number_picture`` images for ``prompt`` and save them as PNG files.

        Args:
            prompt: Text prompt, repeated once per requested picture.
            negative_prompt: Negative prompt, repeated once per requested picture.
            scheduler: Key into ``SCHEDULERS`` selecting the scheduler to use.
            width: Width passed to the base pipeline.
            height: Height passed to the base pipeline.
            guidance_scale: Guidance scale passed to the base pipeline.
            num_inteference_steps: Passed as ``num_inference_steps`` to the
                pipelines (the misspelled name is kept for compatibility).
            number_picture: How many pictures to generate in one call.
            seed: Seed for the random generator shared by both pipelines.
            Refiner: When true, the base pipeline outputs latents that the
                refiner pipeline then finishes.
            Refiner_noise: Value passed as ``denoising_end`` to the base
                pipeline and as ``denoising_start`` to the refiner.

        Returns:
            Paths of the saved images, ``/tmp/generated-<index>.png``.
        """
        generator = torch.Generator().manual_seed(seed)
        prompts = [prompt] * number_picture
        negative_prompts = [negative_prompt] * number_picture

        base_kwargs = {
            "prompt": prompts,
            "negative_prompt": negative_prompts,
            "width": width,
            "height": height,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inteference_steps,
            "generator": generator,
        }

        scheduler_cls = SCHEDULERS[scheduler]
        self.sdxl.scheduler = scheduler_cls.from_config(self.sdxl.scheduler.config)

        if Refiner:
            print("Creating image with refiner model")
            # The refiner's scheduler only needs rebuilding when it is used.
            self.refiner.scheduler = scheduler_cls.from_config(
                self.refiner.scheduler.config
            )

            # The base pipeline stops at `Refiner_noise` and hands over latents.
            latents = self.sdxl(
                **base_kwargs,
                denoising_end=Refiner_noise,
                output_type="latent",
            ).images

            # The seeded generator is passed on so that schedulers which draw
            # random noise stay reproducible in the refiner stage too.
            # NOTE(review): `guidance_scale` is not forwarded to the refiner,
            # so it uses its own default there — confirm this is intended.
            result = self.refiner(
                prompt=prompts,
                negative_prompt=negative_prompts,
                num_inference_steps=num_inteference_steps,
                denoising_start=Refiner_noise,
                image=latents,
                generator=generator,
            )
        else:
            print("Creating image without refiner model")
            result = self.sdxl(**base_kwargs)

        output_paths = []
        for index, sample in enumerate(result.images):
            output_path = f"/tmp/generated-{index}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        return output_paths

Let’s check out my other models!