NeuroBLAST-V3-SYNTH-EC-150000-JAX
27
license:apache-2.0
by
mkurman
Language Model
OTHER
New
27 downloads
Early-stage
Edge AI:
Mobile
Laptop
Server
Unknown
Mobile
Laptop
Server
Quick Summary
AI model with specialized capabilities.
Code Examples
Usage (Python, via transformers)
import argparse
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from neuroblast3_jax.modeling_neuroblast_jax import NeuroBLASTForCausalLM as NeuroBLASTForCausalLMJax
def generate_text(model, tokenizer, text, max_new_tokens=50, temperature=0.7, top_k=50):
    """Autoregressively sample from a JAX causal LM, streaming output to stdout.

    Args:
        model: JAX/Flax causal LM. Called as ``model(input_ids=..., attention_mask=...,
            params=..., train=False)``; must expose ``.params`` and return an object
            with a ``.logits`` array of shape (batch, seq_len, vocab).
        tokenizer: HF-style tokenizer providing ``__call__`` (with
            ``return_tensors="np"``), ``decode``, ``pad_token_id`` and ``eos_token_id``.
        text: User prompt; wrapped in the chat template below.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Softmax temperature applied to the logits before sampling.
        top_k: Sample only among the ``top_k`` highest-probability tokens;
            ``None`` or 0 disables the filter.

    Returns:
        The decoded string — prompt plus generated continuation, special tokens kept.
    """
    # NOTE(review): template begins with "user\n" without a leading "<|im_start|>" —
    # confirm against the model's expected chat format.
    inputs = tokenizer(f"user\n{text}<|im_end|><|im_start|>assistant\n", return_tensors="np")
    original_input_ids = inputs["input_ids"]
    batch_size, prompt_len = original_input_ids.shape
    total_len = prompt_len + max_new_tokens

    # Pre-allocate a fixed-length padded buffer so the jitted step always sees the
    # same shapes (avoids a recompile on every iteration).
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = jnp.full((batch_size, total_len), pad_id, dtype=jnp.int32)
    input_ids = input_ids.at[:, :prompt_len].set(original_input_ids)
    attention_mask = jnp.ones((batch_size, total_len), dtype=jnp.int32)
    params = model.params

    @jax.jit
    def model_step(params, input_ids, attention_mask):
        # One full forward pass over the padded buffer (no KV cache).
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, params=params, train=False)
        return outputs.logits

    rng = jax.random.PRNGKey(0)
    print("Generating...")
    current_len = prompt_len
    printed_len = 0
    for _ in range(max_new_tokens):
        rng, step_rng = jax.random.split(rng)
        logits = model_step(params, input_ids, attention_mask)
        # Logits at the last filled position predict the next token.
        next_token_logits = logits[:, current_len - 1, :]

        scaled_logits = next_token_logits / temperature
        # BUGFIX: `top_k` was accepted but silently ignored. Mask everything below
        # the k-th largest logit so sampling is restricted to the top-k tokens.
        if top_k is not None and top_k > 0:
            k = min(int(top_k), scaled_logits.shape[-1])
            kth_value = jnp.sort(scaled_logits, axis=-1)[:, -k]
            scaled_logits = jnp.where(
                scaled_logits < kth_value[:, None], -jnp.inf, scaled_logits
            )
        next_token = jax.random.categorical(step_rng, scaled_logits, axis=-1)

        # Write the sampled token into the next free slot of the buffer.
        input_ids = input_ids.at[:, current_len].set(next_token)
        current_len += 1

        # Streaming: decode everything so far and print only the unseen suffix.
        # (The prompt itself is also streamed once, since printed_len starts at 0.)
        valid_ids = input_ids[0, :current_len]
        current_text = tokenizer.decode(valid_ids, skip_special_tokens=False)
        new_text = current_text[printed_len:]
        if new_text:
            print(new_text, end="", flush=True)
            printed_len += len(new_text)

        # Stop at end-of-sequence (checks batch element 0 only; batch size is 1 here).
        if next_token[0] == tokenizer.eos_token_id:
            break
    valid_ids = input_ids[0, :current_len]
    return tokenizer.decode(valid_ids, skip_special_tokens=False)
# --- Example: load the checkpoint and run a short generation -------------------
checkpoint = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX"

print(f"Loading model from {checkpoint}...")
# Fast tokenizer; trust_remote_code is required for the custom model repo.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True, trust_remote_code=True)

print(f"Available devices: {jax.devices()}")
# bfloat16 weights keep the memory footprint small on accelerator devices.
model = NeuroBLASTForCausalLMJax.from_pretrained(
    checkpoint, dtype=jnp.bfloat16, trust_remote_code=True, is_decoder=True
)

generated_text = generate_text(model, tokenizer, 'what is hypertension?', 128)

print("\nGenerated Text:")
print("-" * 20)
print(generated_text)
print("-" * 20)

Deploy This Model
Production-ready deployment in minutes
Together.ai
Instant API access to this model
Production-ready inference API. Start free, scale to millions.
Try Free API
Replicate
One-click model deployment
Run models in the cloud with simple API. No DevOps required.
Deploy Now
Disclosure: We may earn a commission from these partners. This helps keep LLMYourWay free.