Spaces:
Runtime error
Runtime error
valentin urena
committed on
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,8 @@ import time
|
|
| 12 |
|
| 13 |
from chess_board import Game
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
| 17 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
@@ -22,13 +24,13 @@ MAX_NEW_TOKENS = 2048
|
|
| 22 |
DEFAULT_MAX_NEW_TOKENS = 128
|
| 23 |
|
| 24 |
# model_id = "hf://google/gemma-2b-keras"
|
| 25 |
-
model_id = "hf://google/gemma-2-2b-it"
|
| 26 |
|
| 27 |
# model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
|
| 28 |
|
| 29 |
|
| 30 |
-
model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
|
| 31 |
-
tokenizer = model.preprocessor.tokenizer
|
| 32 |
|
| 33 |
DESCRIPTION = """
|
| 34 |
# Gemma 2B
|
|
@@ -38,6 +40,16 @@ This game mode allows you to play a game against Gemma, the input must be in alg
|
|
| 38 |
If you need help learning algebraic notation ask Gemma!
|
| 39 |
"""
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
# @spaces.GPU
|
| 42 |
def generate(
|
| 43 |
message: str,
|
|
@@ -45,13 +57,15 @@ def generate(
|
|
| 45 |
max_new_tokens: int = 1024,
|
| 46 |
) -> Iterator[str]:
|
| 47 |
|
| 48 |
-
input_ids = tokenizer.tokenize(message)
|
| 49 |
|
| 50 |
-
if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
response =
|
| 55 |
|
| 56 |
outputs = ""
|
| 57 |
|
|
|
|
| 12 |
|
| 13 |
from chess_board import Game
|
| 14 |
|
| 15 |
+
import google.generativeai as genai
|
| 16 |
+
|
| 17 |
|
| 18 |
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
| 19 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
|
|
| 24 |
DEFAULT_MAX_NEW_TOKENS = 128
|
| 25 |
|
| 26 |
# model_id = "hf://google/gemma-2b-keras"
|
| 27 |
+
# model_id = "hf://google/gemma-2-2b-it"
|
| 28 |
|
| 29 |
# model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
|
| 30 |
|
| 31 |
|
| 32 |
+
# model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
|
| 33 |
+
# tokenizer = model.preprocessor.tokenizer
|
| 34 |
|
| 35 |
DESCRIPTION = """
|
| 36 |
# Gemma 2B
|
|
|
|
| 40 |
If you need help learning algebraic notation ask Gemma!
|
| 41 |
"""
|
| 42 |
|
| 43 |
+
|
| 44 |
+
user_secrets = UserSecretsClient()
|
| 45 |
+
api_key = user_secrets.get_secret("GEMINI_API_KEY")
|
| 46 |
+
genai.configure(api_key = api_key)
|
| 47 |
+
|
| 48 |
+
model = genai.GenerativeModel(model_name='gemini-1.5-flash-latest')
|
| 49 |
+
|
| 50 |
+
# Chat
|
| 51 |
+
chat = model.start_chat()
|
| 52 |
+
|
| 53 |
# @spaces.GPU
|
| 54 |
def generate(
|
| 55 |
message: str,
|
|
|
|
| 57 |
max_new_tokens: int = 1024,
|
| 58 |
) -> Iterator[str]:
|
| 59 |
|
| 60 |
+
# input_ids = tokenizer.tokenize(message)
|
| 61 |
|
| 62 |
+
# if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
|
| 63 |
+
# input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
|
| 64 |
+
# gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
| 65 |
+
|
| 66 |
+
# response = model.generate(message, max_length=max_new_tokens)
|
| 67 |
|
| 68 |
+
response = chat.send_message(message)
|
| 69 |
|
| 70 |
outputs = ""
|
| 71 |
|