生产者消费者模型
说明
- input_queue：待处理数据的输入队列
- output_queue：存放处理结果的队列
- back_data：是否需要返回处理结果（为 0 时结果不会放入 output_queue，get() 直接返回 None）
- 与之前的版本相比，get 和 commit 的调用不再需要一一对应
- 需要执行的处理逻辑写在 Model 类的 forward 方法中。
#!/usr/bin/python3
# **********************************************************
# * Author : leon
# * Email : 1778064763@qq.com
# * Create time : 2024-03-04 07:54
# * Filename : cpm.py
# * Description :
# **********************************************************
import multiprocessing
import queue
class Cpm:
    """Run a model's ``forward`` method in a separate worker process.

    Producer/consumer wrapper: the caller ``commit``s inputs onto
    ``input_queue``; the worker process launched by ``start`` feeds each input
    to ``model.forward`` and, when ``back_data`` is non-zero, puts the result
    on ``output_queue`` for the caller to collect with ``get``.  Calls to
    ``commit`` and ``get`` do not have to alternate one-to-one.

    NOTE(review): ``start`` hands the bound method ``self.worker`` and the
    model to ``multiprocessing.Process``; under the "spawn"/"forkserver"
    start methods both must be picklable -- confirm for real models.
    """

    def __init__(self, back_data=0):
        self.input_queue = multiprocessing.Queue()   # inputs waiting for the worker
        self.output_queue = multiprocessing.Queue()  # results waiting for get()
        self.worker_process = None                   # running worker Process, if any
        self.run = multiprocessing.Value('i', 0)     # 1 while the worker should keep looping
        # Non-zero: forward() results are put on output_queue for get().
        self.back_data = multiprocessing.Value('i', back_data)
        # Number of committed inputs whose result may still arrive on
        # output_queue; get() stops waiting once it reaches 0.  It is modified
        # from both processes, so always change it via _add_pending().
        self.commit_count = multiprocessing.Value('i', 0)

    def _add_pending(self, delta):
        """Atomically add ``delta`` to ``commit_count``."""
        # ``value += n`` on a shared Value is a separate locked read and a
        # locked write, so concurrent updates could be lost without this lock.
        with self.commit_count.get_lock():
            self.commit_count.value += delta

    def _worker_alive(self):
        """Return True if a worker process exists and is still running."""
        proc = self.worker_process
        return proc is not None and proc.is_alive()

    def _drain_input(self):
        """Discard every item currently readable from ``input_queue``.

        Returns the number of items removed.  Items put by this process only
        moments ago may still sit in the queue's feeder buffer and survive;
        they stay counted and will be processed by the next worker.
        """
        drained = 0
        while True:
            try:
                self.input_queue.get_nowait()
            except queue.Empty:
                return drained
            drained += 1

    def stop(self):
        """Stop the worker (if any) and discard inputs it has not taken yet.

        Blocks until the worker process exits.  The worker waits on the input
        queue with a 1 s timeout, so this can take about a second plus the
        duration of any ``forward`` call already in progress.  Results already
        on ``output_queue`` are kept and can be fetched by ``get`` after a
        later ``start``.

        NOTE(review): a child that has put items on a queue does not exit
        until they are flushed to the pipe, so join() may block if many large
        results are left unread.
        """
        self.run.value = 0
        if self.worker_process is not None:
            self.worker_process.join()
            self.worker_process = None
        # Drain only after the worker is gone so nothing competes for the
        # queue.  Discarded inputs will never produce a result, so stop
        # counting them; otherwise a later get() would wait for them forever.
        drained = self._drain_input()
        if drained:
            self._add_pending(-drained)

    def commit(self, input_data):
        """Queue ``input_data`` for processing by the worker."""
        # Count first so the worker can never un-count an item that has not
        # been counted yet.
        self._add_pending(1)
        try:
            self.input_queue.put(input_data)
        except BaseException:
            self._add_pending(-1)
            raise

    def start(self, model_load_method):
        """(Re)start the worker with the model returned by ``model_load_method()``.

        Any running worker is stopped first.  Returns False, leaving no worker
        running, if the loader returns None; otherwise returns True.
        """
        self.stop()
        model = model_load_method()
        if model is None:
            return False
        self.run.value = 1
        self.worker_process = multiprocessing.Process(target=self.worker,
                                                      args=(model,))
        self.worker_process.start()
        return True

    def get(self):
        """Return the next result from the worker, waiting for it if needed.

        Returns None without waiting when the worker is not running,
        ``back_data`` is 0, or no committed input still has a result
        outstanding.  Also returns None if the worker process has exited and
        no result is left.  A ``forward`` that itself returns None is
        indistinguishable from "no result".
        """
        while (self.run.value and self.back_data.value
               and self.commit_count.value > 0):
            try:
                result = self.output_queue.get(timeout=0.01)
            except queue.Empty:
                if self._worker_alive():
                    continue
                # The worker has exited.  A child process does not terminate
                # until everything it put on a queue is flushed to the pipe,
                # so one last non-blocking read catches a result that landed
                # just after the timed get above gave up.
                try:
                    result = self.output_queue.get_nowait()
                except queue.Empty:
                    return None
            self._add_pending(-1)
            return result
        return None

    def worker(self, model):
        """Worker-process loop: pass each queued input to ``model.forward``.

        Runs until ``run`` is cleared.  Exceptions are printed and the
        offending input is skipped.
        """
        while self.run.value:
            try:
                # The timeout makes the loop re-check self.run periodically.
                input_data = self.input_queue.get(timeout=1)
                result = model.forward(input_data)
                if self.back_data.value:
                    self.output_queue.put(result)
                else:
                    # Result is discarded, so it is no longer pending.
                    self._add_pending(-1)
            except queue.Empty:
                pass
            except Exception as e:
                print("Error : ", e)
                # No result for this input will ever reach output_queue.
                self._add_pending(-1)
        del model
# Example usage
class Model:
    """Toy model for the example: ``forward`` adds one to its input."""

    def forward(self, input_data):
        """Return ``input_data + 1``.

        Raises:
            ValueError: if ``input_data == 10`` (presumably there to show the
                worker's error handling in the example below).
        """
        if input_data != 10:
            return input_data + 1
        raise ValueError("Input Error")
def load_model():
    """Loader for ``Cpm.start``: build and return a fresh ``Model``."""
    model = Model()
    return model
if __name__ == '__main__':
    instance = Cpm(1)
    instance.start(load_model)
    try:
        # Nothing has been committed yet, so this returns None immediately.
        instance.get()
        input_data = 10
        instance.commit(input_data)      # Model.forward raises for 10: no result
        instance.commit(input_data + 2)
        instance.commit(input_data + 3)
        result2 = instance.get()
        result3 = instance.get()
        result4 = instance.get()         # only two results exist -> None
        print(result2)
        print(result3)
        print(result4)
    finally:
        # Without stop() the worker keeps looping and, being a non-daemon
        # process, keeps the interpreter from exiting.
        instance.stop()
网友评论