You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
900 B
37 lines
900 B
import math

from darknet import *
|
|
|
def predict_tactic(net, s, max_len=None):
    """Sample one tactic string from ``net``, continuing the prompt ``s``.

    Each character is fed to the network as a one-hot vector over 256 slots
    (indexed by ``ord(c)``). Every prompt character except the last is fed
    only to advance the network's state. Generation then starts from the last
    prompt character, drawing each next character with ``sample`` from the
    network output, until a ``'.'`` is produced.

    Args:
        net: network handle, as returned by ``load_net`` in this script.
            Callers in this file call ``reset_rnn(net)`` before each call.
        s: prompt string. An empty prompt is treated as ``'\\n'``. Every
            character must satisfy ``ord(c) < 256`` (the input buffer has
            256 slots).
        max_len: optional cap on the number of generated characters. ``None``
            (the default) samples until a ``'.'`` appears, which is not
            guaranteed to happen.

    Returns:
        ``(tac, prob)``. ``tac`` is the generated text. It ends with ``'.'``
        unless ``max_len`` stopped generation first. ``prob`` is the sum of
        ``math.log`` of the network output value for each character in
        ``tac``, and only those characters. A character whose output value is
        not positive contributes ``-inf``.
    """
    vocab = 256
    if not s:
        s = '\n'
    # One reusable one-hot input buffer: set a slot, run the net, clear it.
    onehot = c_array(c_float, [0.0] * vocab)

    def step(ch):
        # Feed a single character and return the raw network output.
        onehot[ord(ch)] = 1
        out = predict(net, onehot)
        onehot[ord(ch)] = 0
        return out

    # Prime the network with all but the last prompt character; the outputs
    # are discarded, only the side effect of running the net matters here.
    for ch in s[:-1]:
        step(ch)

    chars = []
    prob = 0
    c = s[-1]
    while max_len is None or len(chars) < max_len:
        # Copy the distribution out of the returned buffer before sampling.
        pred = step(c)[:vocab]
        ind = sample(pred)
        p = pred[ind]
        # Guard: math.log raises ValueError for non-positive inputs.
        prob += math.log(p) if p > 0 else float('-inf')
        c = chr(ind)
        chars.append(c)
        # Stop right after emitting '.', so no extra character is sampled
        # and no extra log-probability is counted.
        if c == '.':
            break
    return (''.join(chars), prob)
|
|
|
def predict_tactics(net, s, n):
    """Draw ``n`` tactic samples for prompt ``s`` and rank them.

    ``reset_rnn(net)`` is called before every draw, so each call to
    ``predict_tactic`` starts from a freshly reset network state.

    Returns:
        A new list of the ``(tactic, score)`` pairs from ``predict_tactic``,
        ordered by descending score. Equal scores keep their draw order.
    """
    samples = []
    for _ in range(n):
        reset_rnn(net)
        samples.append(predict_tactic(net, s))
    # Negated key gives highest score first; list.sort is stable.
    samples.sort(key=lambda pair: -pair[1])
    return samples
|
|
|
# Demo: load the network from cfg/coq.test.cfg plus its saved weights, then
# print 10 sampled tactics for the prompt "+++++\n", highest score first.
# NOTE(review): the weights path is machine-specific; adjust for your setup.
# Paths are byte literals: ctypes c_char_p arguments must be bytes on
# Python 3 (assumes load_net is the raw ctypes binding from darknet.py), and
# b"..." is an ordinary str on Python 2, so behaviour there is unchanged.
net = load_net(b"cfg/coq.test.cfg", b"/home/pjreddie/backup/coq.backup", 0)
t = predict_tactics(net, "+++++\n", 10)
# print(...) with a single argument behaves the same on Python 2 and 3.
print(t)
|
|
|