diff --git a/testing.py b/testing.py
index 7f8fa5a..612ace5 100644
--- a/testing.py
+++ b/testing.py
@@ -5,12 +5,24 @@ from flask import Flask, request
 
 app = Flask(__name__)
 
-@app.route("/query",methods=["POST"])
-def generate_response():
+
+@app.route("/query/",methods=["POST"])
+@app.route("/query/<req_model>",methods=["POST"])
+def generate_response(req_model="phi"):
+    if req_model not in ['phi','granite']: return "Only models phi and granite are supported.", 400
+
     try:
         query = request.form['query']
+
+
+        if req_model == 'granite':
+            model = model_granite
+            tokenizer = tokenizer_granite
+        else:
+            model = model_phi
+            tokenizer = tokenizer_phi
         messages = [
-            {"role": "system", "content": "You are an evil and unhelpful AI assistant."},
+            {"role": "system", "content": "You are a helpful AI assistant."},
             {"role": "user","content": query}
         ]
         pipe = pipeline(
@@ -34,13 +46,13 @@ def generate_response():
 
 if __name__ == '__main__':
     torch.random.manual_seed(0)
-    model = AutoModelForCausalLM.from_pretrained(
+    model_phi = AutoModelForCausalLM.from_pretrained(
         "/home/fedora/microsoft/Phi-3-mini-4k-instruct",
         device_map="cuda",
         torch_dtype="auto",
         trust_remote_code=True,
     )
-    model = AutoModelForCausalLM.from_pretrained(
+    model_granite = AutoModelForCausalLM.from_pretrained(
         "/home/fedora/granite-3b-code-instruct",
         device_map="cuda",
         torch_dtype="auto",
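
A quick way to exercise the changed routes once the server is up — a minimal sketch assuming the app runs locally on Flask's default port 5000 (the host, port, and sample queries are assumptions; the endpoint paths and the "query" form field come from the diff above):

import requests

BASE = "http://127.0.0.1:5000"  # assumed local dev server (Flask default port)

# Bare route: req_model falls back to its default, "phi"
r = requests.post(f"{BASE}/query/", data={"query": "Hello"})
print(r.status_code, r.text)

# Explicit model via the <req_model> path parameter
r = requests.post(f"{BASE}/query/granite", data={"query": "Write a sorting function"})
print(r.status_code, r.text)

# Anything other than phi/granite should be rejected with HTTP 400
r = requests.post(f"{BASE}/query/llama", data={"query": "Hello"})
print(r.status_code, r.text)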