Spaces:
Running
Running
Commit
·
aea0528
1
Parent(s):
574628b
Add way of limiting the complexity of power operators
Browse files- julia/sr.jl +28 -0
- pysr/sr.py +5 -0
julia/sr.jl
CHANGED
|
@@ -599,6 +599,27 @@ mutable struct PopMember
|
|
| 599 |
|
| 600 |
end
|
| 601 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
# Go through one simulated annealing mutation cycle
|
| 603 |
# exp(-delta/T) defines probability of accepting a change
|
| 604 |
function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
|
|
@@ -652,6 +673,13 @@ function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
|
|
| 652 |
return PopMember(tree, beforeLoss)
|
| 653 |
end
|
| 654 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
if batching
|
| 656 |
afterLoss = scoreFuncBatch(tree)
|
| 657 |
else
|
|
|
|
| 599 |
|
| 600 |
end
|
| 601 |
|
| 602 |
+
# Check if any power operator is to the power of a complex expression
|
| 603 |
+
function deepPow(tree::Node)::Integer
|
| 604 |
+
if tree.degree == 0
|
| 605 |
+
return 0
|
| 606 |
+
elseif tree.degree == 1
|
| 607 |
+
return 0 + deepPow(tree.l)
|
| 608 |
+
else
|
| 609 |
+
if binops[tree.op] == pow
|
| 610 |
+
complexity_in_power = countNodes(tree.r)
|
| 611 |
+
is_deep_pow = (complexity_in_power > 1)
|
| 612 |
+
if is_deep_pow
|
| 613 |
+
return 1 + deepPow(tree.l)
|
| 614 |
+
else
|
| 615 |
+
return 0 + deepPow(tree.l)
|
| 616 |
+
end
|
| 617 |
+
else
|
| 618 |
+
return 0 + deepPow(tree.l) + deepPow(tree.r)
|
| 619 |
+
end
|
| 620 |
+
end
|
| 621 |
+
end
|
| 622 |
+
|
| 623 |
# Go through one simulated annealing mutation cycle
|
| 624 |
# exp(-delta/T) defines probability of accepting a change
|
| 625 |
function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
|
|
|
|
| 673 |
return PopMember(tree, beforeLoss)
|
| 674 |
end
|
| 675 |
|
| 676 |
+
|
| 677 |
+
# Check for illegal functions
|
| 678 |
+
if limitPowComplexity && (deepPow(tree) > 0)
|
| 679 |
+
return PopMember(copyNode(prev), beforeLoss)
|
| 680 |
+
end
|
| 681 |
+
|
| 682 |
+
|
| 683 |
if batching
|
| 684 |
afterLoss = scoreFuncBatch(tree)
|
| 685 |
else
|
pysr/sr.py
CHANGED
|
@@ -87,6 +87,7 @@ def pysr(X=None, y=None, weights=None,
|
|
| 87 |
batchSize=50,
|
| 88 |
select_k_features=None,
|
| 89 |
warmupMaxsize=0,
|
|
|
|
| 90 |
threads=None, #deprecated
|
| 91 |
julia_optimization=3,
|
| 92 |
):
|
|
@@ -163,6 +164,9 @@ def pysr(X=None, y=None, weights=None,
|
|
| 163 |
a small number up to the maxsize (if greater than 0).
|
| 164 |
If greater than 0, says how many cycles before the maxsize
|
| 165 |
is increased.
|
|
|
|
|
|
|
|
|
|
| 166 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
| 167 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 168 |
(as strings).
|
|
@@ -275,6 +279,7 @@ const mutationWeights = [
|
|
| 275 |
{weightDoNothing:f}
|
| 276 |
]
|
| 277 |
const warmupMaxsize = {warmupMaxsize:d}
|
|
|
|
| 278 |
"""
|
| 279 |
|
| 280 |
op_runner = ""
|
|
|
|
| 87 |
batchSize=50,
|
| 88 |
select_k_features=None,
|
| 89 |
warmupMaxsize=0,
|
| 90 |
+
limitPowComplexity=False,
|
| 91 |
threads=None, #deprecated
|
| 92 |
julia_optimization=3,
|
| 93 |
):
|
|
|
|
| 164 |
a small number up to the maxsize (if greater than 0).
|
| 165 |
If greater than 0, says how many cycles before the maxsize
|
| 166 |
is increased.
|
| 167 |
+
:param limitPowComplexity: bool, whether to prevent pow from having
|
| 168 |
+
complex right arguments. I.e., 3.0^(x+y) becomes impossible,
|
| 169 |
+
but 3.0^x is possible.
|
| 170 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
| 171 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 172 |
(as strings).
|
|
|
|
| 279 |
{weightDoNothing:f}
|
| 280 |
]
|
| 281 |
const warmupMaxsize = {warmupMaxsize:d}
|
| 282 |
+
const limitPowComplexity = {"true" if limitPowComplexity else "false"}
|
| 283 |
"""
|
| 284 |
|
| 285 |
op_runner = ""
|