NeuroBLAST-V3-SYNTH-EC-150000-JAX

27
license:apache-2.0
by
mkurman
Language Model
OTHER
New
27 downloads
Early-stage
Edge AI:
Mobile
Laptop
Server
Unknown
Mobile
Laptop
Server
Quick Summary

A JAX implementation of the NeuroBLAST causal language model, packaged for text generation via the transformers tokenizer ecosystem.

Code Examples

Usage (Python, with jax and transformers):
import argparse
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from neuroblast3_jax.modeling_neuroblast_jax import NeuroBLASTForCausalLM as NeuroBLASTForCausalLMJax

def generate_text(model, tokenizer, text, max_new_tokens=50, temperature=0.7, top_k=50):
    """Autoregressively sample a continuation from ``model`` for a chat prompt.

    Args:
        model: JAX/Flax causal LM; calling it must return an object with a
            ``.logits`` array of shape (batch, seq_len, vocab), and it must
            expose ``.params``.
        tokenizer: HF-style tokenizer providing ``__call__`` (with
            ``return_tensors="np"``), ``decode``, ``pad_token_id`` and
            ``eos_token_id``.
        text: User message inserted into the chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature for sampling; must be > 0.
        top_k: Restrict sampling to the ``top_k`` most likely tokens.
            ``None`` or 0 disables the filter.

    Returns:
        Decoded prompt plus generated continuation (special tokens kept).
    """
    inputs = tokenizer(f"user\n{text}<|im_end|><|im_start|>assistant\n", return_tensors="np")
    original_input_ids = inputs["input_ids"]
    batch_size, prompt_len = original_input_ids.shape
    total_len = prompt_len + max_new_tokens

    # Pre-allocate a fixed-size buffer so the jitted step always sees the
    # same input shape and is compiled exactly once.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = jnp.full((batch_size, total_len), pad_id, dtype=jnp.int32)
    input_ids = input_ids.at[:, :prompt_len].set(original_input_ids)

    # An all-ones mask over the padded buffer is safe for a causal LM: we
    # only ever read logits at position current_len - 1, and causal masking
    # prevents that position from attending to the (future) pad slots.
    attention_mask = jnp.ones((batch_size, total_len), dtype=jnp.int32)
    params = model.params

    @jax.jit
    def model_step(params, input_ids, attention_mask):
        # Fix: dropped the unused ``rng`` argument the original accepted.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, params=params, train=False)
        return outputs.logits

    rng = jax.random.PRNGKey(0)

    print("Generating...")
    current_len = prompt_len
    printed_len = 0

    for _ in range(max_new_tokens):
        rng, step_rng = jax.random.split(rng)

        logits = model_step(params, input_ids, attention_mask)

        # Logits for the last real token (position current_len - 1).
        next_token_logits = logits[:, current_len - 1, :]

        scaled_logits = next_token_logits / temperature

        # Fix: the original declared ``top_k`` but never applied it. Mask
        # everything outside the k highest-probability tokens before sampling.
        if top_k:
            k = min(top_k, scaled_logits.shape[-1])
            kth_value = jax.lax.top_k(scaled_logits, k)[0][:, -1:]
            scaled_logits = jnp.where(scaled_logits < kth_value, -jnp.inf, scaled_logits)

        next_token = jax.random.categorical(step_rng, scaled_logits, axis=-1)

        # Write the sampled token into the next free slot of the buffer.
        input_ids = input_ids.at[:, current_len].set(next_token)
        current_len += 1

        # Stream only the newly decoded suffix to stdout.
        valid_ids = input_ids[0, :current_len]
        current_text = tokenizer.decode(valid_ids, skip_special_tokens=False)
        new_text = current_text[printed_len:]
        if new_text:
            print(new_text, end="", flush=True)
            printed_len += len(new_text)

        # Stop once the model emits the end-of-sequence token.
        if next_token[0] == tokenizer.eos_token_id:
            break

    valid_ids = input_ids[0, :current_len]
    return tokenizer.decode(valid_ids, skip_special_tokens=False)


  checkpoint = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX"

  print(f"Loading model from {checkpoint}...")
  tokenizer = AutoTokenizer.from_pretrained(
      checkpoint,
      use_fast=True,
      trust_remote_code=True,
  )

  print(f"Available devices: {jax.devices()}")

  model = NeuroBLASTForCausalLMJax.from_pretrained(
      checkpoint,
      dtype=jnp.bfloat16, 
      trust_remote_code=True,
      is_decoder=True,
  )
  
  generated_text = generate_text(model, tokenizer, 'what is hypertension?', 128)
  
  print("\nGenerated Text:")
  print("-" * 20)
  print(generated_text)
  print("-" * 20)

Deploy This Model

Production-ready deployment in minutes

Together.ai

Instant API access to this model

Fastest API

Production-ready inference API. Start free, scale to millions.

Try Free API

Replicate

One-click model deployment

Easiest Setup

Run models in the cloud with simple API. No DevOps required.

Deploy Now

Disclosure: We may earn a commission from these partners. This helps keep LLMYourWay free.