tiny-flux-deep

1.3K
2
license:mit
by
AbstractPhil
Image Model
OTHER
New
1K downloads
Early-stage
Edge AI:
Mobile
Laptop
Server
Unknown
Mobile
Laptop
Server
Quick Summary

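TinyFlux-Deep is a compact Flux-style rectified-flow image model conditioned on T5 and CLIP text embeddings; the examples below sample 64×64 latents with Euler steps and decode them with the FLUX.1-schnell VAE.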

Code Examples

Then call: image = generate("your prompt here") — Python, PyTorch
import torch
import torch.nn.functional as F

def flux_shift(t, s=3.0):
    """Apply the Flux-style timestep shift ``s * t / (1 + (s - 1) * t)``.

    The endpoints 0 and 1 map to themselves; with the default ``s=3.0``
    interior values move toward 1 (for example 0.5 becomes 0.75), which is
    the data end of the schedule used by the samplers in this document.
    """
    scaled = s * t
    denominator = 1 + (s - 1) * t
    return scaled / denominator

def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
    """Euler sampling with classifier-free guidance.

    Integrates the model's predicted velocity from t=0 (noise) to t=1 (data)
    on a Flux-shifted timestep grid, for one 64x64 grid of 16-channel latents.

    Args:
        model: Callable module invoked with ``hidden_states``,
            ``encoder_hidden_states``, ``pooled_projections``, ``timestep``
            and ``img_ids``. Its first parameter determines the device and
            dtype of the initial noise.
        t5_emb: Text embeddings passed as ``encoder_hidden_states``.
        clip_emb: Accepted for interface compatibility; this sampler does not
            read it.
        clip_pooled: Pooled embedding passed as ``pooled_projections``.
        num_steps: Number of Euler steps.
        cfg_scale: Guidance scale applied to ``v_cond - v_uncond``.

    Returns:
        The integrated latents (initialised as a ``[1, 4096, 16]`` tensor),
        to be decoded with the VAE.
    """
    first_param = next(model.parameters())
    device = first_param.device
    dtype = first_param.dtype

    # Start from pure noise (t=0)
    x = torch.randn(1, 64 * 64, 16, device=device, dtype=dtype)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)

    # Rectified flow: integrate from t=0 (noise) to t=1 (data)
    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device=device))

    # Sampling is inference-only. Without no_grad every one of the
    # 2 * num_steps forward passes would keep its autograd graph alive.
    with torch.no_grad():
        for i in range(num_steps):
            t_curr = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_curr

            t_batch = t_curr.expand(1)

            # Conditional prediction
            v_cond = model(
                hidden_states=x,
                encoder_hidden_states=t5_emb,
                pooled_projections=clip_pooled,
                timestep=t_batch,
                img_ids=img_ids,
            )

            # Unconditional prediction (for CFG): zeroed conditioning inputs
            v_uncond = model(
                hidden_states=x,
                encoder_hidden_states=torch.zeros_like(t5_emb),
                pooled_projections=torch.zeros_like(clip_pooled),
                timestep=t_batch,
                img_ids=img_ids,
            )

            # Classifier-free guidance
            v = v_uncond + cfg_scale * (v_cond - v_uncond)

            # Euler step
            x = x + v * dt

    return x  # [1, 4096, 16] - decode with VAE
Lune Expert Predictor (Trajectory Guidance) — Python
# Example constructor call listing the LuneExpertPredictor arguments.
# NOTE(review): the class is not defined or imported in this snippet;
# presumably it comes from the TinyFlux-Deep model script. Confirm before
# running this standalone.
LuneExpertPredictor(
    time_dim=512,        # From timestep MLP
    clip_dim=768,        # CLIP pooled features
    expert_dim=1280,     # SD1.5 mid-block dimension (prediction target)
    hidden_dim=512,      # Internal MLP width
    output_dim=512,      # Output added to vec
    dropout=0.1,
)
Learnable balance between CLIP and T5 — Python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class TinyFluxConfig:
    """Configuration values for the TinyFlux-Deep model.

    The defaults describe a 512-wide model with 4 attention heads of
    dimension 128, 15 double layers and 25 single layers.

    Raises:
        ValueError: On construction, if ``hidden_size`` differs from
            ``num_attention_heads * attention_head_dim`` or if
            ``axes_dims_rope`` does not sum to ``attention_head_dim``.
    """

    # Core architecture
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128  # hidden_size = heads × head_dim
    in_channels: int = 16          # VAE latent channels
    patch_size: int = 1
    joint_attention_dim: int = 768  # T5 embedding dim
    pooled_projection_dim: int = 768  # CLIP pooled dim
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)  # Must sum to head_dim

    # Lune expert predictor
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280    # SD1.5 mid-block dim
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior
    use_sol_prior: bool = True
    sol_spatial_size: int = 8      # 8×8 spatial importance map
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7  # 70% geometric, 30% learned

    # T5 enhancement
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"  # "attention", "mean", "cls"

    # Loss configuration
    lune_distill_mode: str = "cosine"  # "hard", "soft", "cosine", "huber"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    # Legacy compatibility
    guidance_embeds: bool = False

    def __post_init__(self) -> None:
        """Enforce the dimension invariants noted beside the fields."""
        expected_hidden = self.num_attention_heads * self.attention_head_dim
        if self.hidden_size != expected_hidden:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({expected_hidden})"
            )

        rope_total = sum(self.axes_dims_rope)
        if rope_total != self.attention_head_dim:
            raise ValueError(
                f"axes_dims_rope {tuple(self.axes_dims_rope)} sums to "
                f"{rope_total}, expected attention_head_dim "
                f"({self.attention_head_dim})"
            )
Validate checkpoint compatibility — Python, Transformers
import torch
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Load text encoders (moved to CUDA in bfloat16)
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)

# Load VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", 
    subfolder="vae",
    torch_dtype=torch.bfloat16
).to("cuda")

# Load TinyFlux-Deep
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
# SECURITY: this executes Python downloaded from the Hub in the current
# namespace; the code below relies on it to provide TinyFluxDeep. Only run it
# for a repository you trust, ideally pinned to a revision you have reviewed.
# The file is read through a context manager so the handle is closed, and as
# UTF-8 so the result does not depend on the platform's locale encoding.
with open(model_py, encoding="utf-8") as model_file:
    exec(model_file.read())

config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
# strict=False tolerates key mismatches between checkpoint and model; surface
# them instead of discarding the report, so a wrong checkpoint is noticed.
load_result = model.load_state_dict(weights, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Checkpoint mismatch: {len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )
model.eval()

def encode_prompt(prompt):
    """Encode a prompt with both text encoders.

    The prompt is tokenized for T5 and for CLIP with the same settings
    (padded/truncated to 77 tokens) and run through each encoder on CUDA.

    Returns:
        ``(t5_emb, clip_pooled)``: the T5 ``last_hidden_state`` and the CLIP
        ``pooler_output``, both cast to bfloat16.
    """
    tokenizer_kwargs = dict(
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    )
    t5_inputs = t5_tokenizer(prompt, **tokenizer_kwargs).to("cuda")
    clip_inputs = clip_tokenizer(prompt, **tokenizer_kwargs).to("cuda")

    # Inference only: neither encoder needs an autograd graph.
    with torch.no_grad():
        t5_hidden = t5_model(**t5_inputs).last_hidden_state.to(torch.bfloat16)
        clip_vector = clip_model(**clip_inputs).pooler_output.to(torch.bfloat16)

    return t5_hidden, clip_vector

def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """
    Sample one image for ``prompt`` with Euler steps on a rectified flow.

    Flow: x_t = (1-t)*noise + t*data
    Integrate from t=0 (noise) to t=1 (data) on a Flux-shifted timestep grid,
    combining conditional and empty-prompt predictions with classifier-free
    guidance, then decode the latents with the VAE.

    Args:
        prompt: Text passed to ``encode_prompt``.
        num_steps: Number of Euler steps.
        cfg_scale: Guidance scale applied to ``v_cond - v_uncond``.
        seed: If not None, passed to ``torch.manual_seed`` before sampling.

    Returns:
        A PIL image built from the decoded, uint8-converted pixels.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Prompt conditioning, plus the empty-prompt embeddings used as the
    # unconditional branch for CFG.
    cond_t5, cond_pooled = encode_prompt(prompt)
    null_t5, null_pooled = encode_prompt("")

    # Start from pure noise (t=0): one 64x64 grid of 16-channel latents.
    side, channels = 64, 16
    x = torch.randn(1, side * side, channels, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, side, side, "cuda")

    # Rectified flow: 0 → 1 with Flux shift s * t / (1 + (s - 1) * t), s = 3.0
    shift = 3.0
    grid = torch.linspace(0, 1, num_steps + 1, device="cuda")
    timesteps = shift * grid / (1 + (shift - 1) * grid)

    with torch.no_grad():
        # Walk consecutive (current, next) timestep pairs.
        for t_now, t_next in zip(timesteps[:-1], timesteps[1:]):
            t = t_now.expand(1)
            dt = t_next - t_now  # Positive

            v_cond = model(x, cond_t5, cond_pooled, t, img_ids)
            v_uncond = model(x, null_t5, null_pooled, t, img_ids)

            # CFG followed by the Euler update
            v = v_uncond + cfg_scale * (v_cond - v_uncond)
            x = x + v * dt

        # Decode with VAE: token sequence back to [B, C, H, W]
        latents = x.reshape(1, side, side, channels).permute(0, 3, 1, 2)
        latents = latents / vae.config.scaling_factor
        decoded = vae.decode(latents).sample

    # Convert to PIL: map [-1, 1] to [0, 1], go channel-last, scale to uint8
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    pixels = pixels[0].permute(1, 2, 0).cpu().float().numpy()
    pixels = (pixels * 255).astype("uint8")

    from PIL import Image
    return Image.fromarray(pixels)

# Generate one sample with a fixed seed and write it to disk.
image = generate_image(prompt="a photograph of a tiger in natural habitat", seed=42)
image.save("tiger.png")
Generate! — Python
def generate_batch(prompts, **kwargs):
    """Generate one image per prompt, in order.

    Every keyword argument is forwarded unchanged to ``generate_image`` for
    each prompt; the results are returned as a list.
    """
    images = []
    for text in prompts:
        images.append(generate_image(text, **kwargs))
    return images

# Example batch: three prompts sampled with the same step count and guidance.
batch_prompts = [
    "a red bird with blue beak",
    "a mountain landscape at sunset",
    "an astronaut riding a horse",
]
images = generate_batch(batch_prompts, num_steps=25, cfg_scale=4.0)

Deploy This Model

Production-ready deployment in minutes

Together.ai

Instant API access to this model

Fastest API

Production-ready inference API. Start free, scale to millions.

Try Free API

Replicate

One-click model deployment

Easiest Setup

Run models in the cloud with simple API. No DevOps required.

Deploy Now

Disclosure: We may earn a commission from these partners. This helps keep LLMYourWay free.